// // Copyright © 2024 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // #include "ScatterNd.hpp" #include "Encoders.hpp" #include #include #include #include namespace armnn { float ScatterOperation(ScatterNdFunction operation, float input, float update) { switch (operation) { case ScatterNdFunction::Update: return update; case ScatterNdFunction::Add: return input + update; case ScatterNdFunction::Sub: return input - update; case ScatterNdFunction::Max: return std::max(input, update); case ScatterNdFunction::Min: return std::min(input, update); case ScatterNdFunction::Mul: return input * update; default: throw InvalidArgumentException("ScatterNd: cannot execute this operation."); } } void ScatterNd(const TensorInfo& inputInfo, const TensorInfo& indicesInfo, const TensorInfo& updatesInfo, Decoder& input, Decoder& indices, Decoder& updates, Encoder& output, ScatterNdDescriptor descriptor) { // Axis Unsupported if (descriptor.m_AxisEnabled) { throw InvalidArgumentException("ScatterNd: axis param not supported."); } // Get the shape for indices, updates, and input TensorShape indicesShape = indicesInfo.GetShape(); TensorShape updatesShape = updatesInfo.GetShape(); TensorShape inputShape = inputInfo.GetShape(); // Get the dimensions for indices and updates unsigned int dimension = inputInfo.GetNumDimensions(); unsigned int indicesDim = indicesInfo.GetNumDimensions(); unsigned int updatesDim = updatesInfo.GetNumDimensions(); // Calculate the outter and inner dimensions unsigned int outterDim = indicesShape[indicesDim - 1]; unsigned int innerDim = dimension - outterDim; // Calculate the number of elements in each dimension unsigned int numElementsCount = 1; std::vector elementInDim(dimension); for (unsigned int dimIndex = dimension; dimIndex > 0; --dimIndex) { elementInDim[dimIndex - 1] = numElementsCount; numElementsCount *= inputShape[dimIndex - 1]; } // Number of updates per index unsigned int numUpdatesPerIndex = elementInDim[dimension - innerDim - 1]; // Number of indices to update unsigned int numIndices = indicesShape[0]; // Check Input Requirements // Requirement 1: Indices and Updates must have rank at least 1 if (indicesDim < 1 || updatesDim < 1) { throw InvalidArgumentException("ScatterNd: indices and updates must have rank >= 1."); } // Requirement 2: Input, Indices and Updates must have values if (inputInfo.GetNumElements() == 0 || indicesInfo.GetNumElements() == 0 || updatesInfo.GetNumElements() == 0) { throw InvalidArgumentException("ScatterNd: input, indices and updates tensor must have values."); } // Requirement 3: Indices and Updates must match in shape // The updates dimension should equals to 1 + inner dimension if (updatesDim != 1 + innerDim) { throw InvalidArgumentException("ScatterNd: updates dimension should equal to 1 + inner dimension."); } // The inner dimension of updates has to match with shape of input for (unsigned int dimBackIndex = 0; dimBackIndex < innerDim; ++dimBackIndex) { if (updatesShape[updatesDim - dimBackIndex - 1] != inputShape[dimension - dimBackIndex - 1]) { throw InvalidArgumentException( fmt::format("ScatterNd: input and updates shape not match on dimension {}", dimension - dimBackIndex)); } } // Requirement 4: Check duplicate indices and out of bound indices std::set indicesSet; std::vector flattenIndices(numIndices); for (unsigned int indicesIdx = 0; indicesIdx < numIndices; ++indicesIdx) { // Get the index int flattenIndex = 0; for (unsigned int outterIdx = 0; outterIdx < outterDim; ++outterIdx) { int outterIndexValue = indices.Get(); // Check bounds if (outterIndexValue < 0 || outterIndexValue >= int(inputShape[outterIdx])) { throw InvalidArgumentException( fmt::format("ScatterNd: indices {} out of bound [0, {})", outterIndexValue, inputShape[outterIdx])); } flattenIndex += int(elementInDim[outterIdx]) * outterIndexValue; ++indices; } // Check duplicates when executing ScatterNd::Update if (descriptor.m_Function == ScatterNdFunction::Update && indicesSet.find(flattenIndex) != indicesSet.end()) { throw InvalidArgumentException( fmt::format("ScatterNd: duplicate indices occurs {}", flattenIndex)); } flattenIndices[indicesIdx] = flattenIndex; indicesSet.insert(flattenIndex); } // Set the input data to output for (unsigned int idx = 0; idx < inputInfo.GetNumElements(); ++idx) { float inputValue = input.Get(); ++input; output.Set(inputValue); ++output; } // Iterate through all indices to scatter updates for (unsigned int indicesIdx = 0; indicesIdx < numIndices; ++indicesIdx) { // Get the index and calculate the flatten index int flattenIndex = flattenIndices[indicesIdx]; // FlattenIndex is the place that we are going to update the elements unsigned int updatesStartIdx = indicesIdx * numUpdatesPerIndex; for (unsigned int updatesIdx = 0; updatesIdx < numUpdatesPerIndex; ++updatesIdx) { updates[updatesStartIdx + updatesIdx]; input[static_cast(flattenIndex) + updatesIdx]; float updateValue = ScatterOperation(descriptor.m_Function, input.Get(), updates.Get()); output[static_cast(flattenIndex) + updatesIdx]; output.Set(updateValue); } } } void ScatterNd(const TensorInfo& indicesInfo, const TensorInfo& updatesInfo, const TensorInfo& shapeInfo, Decoder& indices, Decoder& updates, Decoder& shape, Encoder& output, ScatterNdDescriptor descriptor) { // Axis Unsupported if (descriptor.m_AxisEnabled) { throw InvalidArgumentException("ScatterNd: axis param not supported."); } // Get the shape for indices, updates, and input TensorShape indicesShape = indicesInfo.GetShape(); TensorShape updatesShape = updatesInfo.GetShape(); // Get the shape values std::vector shapeValues = shape.DecodeTensor(shapeInfo.GetShape()); // Check the shape if (shapeInfo.GetNumElements() == 0) { throw InvalidArgumentException("ScatterNd: shape must have values."); } for (auto shapeValue : shapeValues) { if (shapeValue <= 0) { throw InvalidArgumentException("ScatterNd: shape values must >= 0."); } } // Get the input shape std::vector inputShape (shapeValues.begin(), shapeValues.end()); unsigned int inputElementsNum = static_cast( std::accumulate(inputShape.begin(), inputShape.end(), 1, std::multiplies())); // Get the dimensions for indices and updates unsigned int dimension = shapeInfo.GetNumElements(); unsigned int indicesDim = indicesInfo.GetNumDimensions(); unsigned int updatesDim = updatesInfo.GetNumDimensions(); // Calculate the outter and inner dimensions unsigned int outterDim = indicesShape[indicesDim - 1]; unsigned int innerDim = dimension - outterDim; // Calculate the number of elements in each dimension unsigned int numElementsCount = 1; std::vector elementInDim(dimension); for (unsigned int dimIndex = dimension; dimIndex > 0; --dimIndex) { elementInDim[dimIndex - 1] = numElementsCount; numElementsCount *= inputShape[dimIndex - 1]; } // Number of updates per index unsigned int numUpdatesPerIndex = elementInDim[dimension - innerDim - 1]; // Number of indices to update unsigned int numIndices = indicesShape[0]; // Check Input Requirements // Requirement 1: Indices and Updates must have rank at least 1 if (indicesDim < 1 || updatesDim < 1) { throw InvalidArgumentException("ScatterNd: indices and updates must have rank >= 1."); } // Requirement 2: shape, Indices and Updates must have values if (indicesInfo.GetNumElements() == 0 || updatesInfo.GetNumElements() == 0) { throw InvalidArgumentException("ScatterNd: indices and updates tensor must have values."); } // Requirement 3: Indices and Updates must match in shape // The updates dimension should equals to 1 + inner dimension if (updatesDim != 1 + innerDim) { throw InvalidArgumentException("ScatterNd: updates dimension should equal to 1 + inner dimension."); } // The inner dimension of updates has to match with shape of input for (unsigned int dimBackIndex = 0; dimBackIndex < innerDim; ++dimBackIndex) { if (updatesShape[updatesDim - dimBackIndex - 1] != inputShape[dimension - dimBackIndex - 1]) { throw InvalidArgumentException( fmt::format("ScatterNd: input and updates shape not match on dimension {}", dimension - dimBackIndex)); } } // Requirement 4: Check duplicate indices and out of bound indices std::set indicesSet; std::vector flattenIndices(numIndices); for (unsigned int indicesIdx = 0; indicesIdx < numIndices; ++indicesIdx) { // Get the index int flattenIndex = 0; for (unsigned int outterIdx = 0; outterIdx < outterDim; ++outterIdx) { int outterIndexValue = indices.Get(); // Check bounds if (outterIndexValue < 0 || outterIndexValue >= int(inputShape[outterIdx])) { throw InvalidArgumentException( fmt::format("ScatterNd: indices {} out of bound [0, {})", outterIndexValue, inputShape[outterIdx])); } flattenIndex += int(elementInDim[outterIdx]) * outterIndexValue; ++indices; } // Check duplicates when executing ScatterNd::Update if (descriptor.m_Function == ScatterNdFunction::Update && indicesSet.find(flattenIndex) != indicesSet.end()) { throw InvalidArgumentException( fmt::format("ScatterNd: duplicate indices {} occurs when executing ScatterNd::Update.", flattenIndex)); } flattenIndices[indicesIdx] = flattenIndex; indicesSet.insert(flattenIndex); } // Set zeros to output for (unsigned int idx = 0; idx < inputElementsNum; ++idx) { output.Set(0.0f); ++output; } // Iterate through all indices to scatter updates for (unsigned int indicesIdx = 0; indicesIdx < numIndices; ++indicesIdx) { // Get the index and calculate the flatten index int flattenIndex = flattenIndices[indicesIdx]; // FlattenIndex is the place that we are going to update the elements unsigned int updatesStartIdx = indicesIdx * numUpdatesPerIndex; for (unsigned int updatesIdx = 0; updatesIdx < numUpdatesPerIndex; ++updatesIdx) { updates[updatesStartIdx + updatesIdx]; float updateValue = ScatterOperation(descriptor.m_Function, 0.0f, updates.Get()); output[static_cast(flattenIndex) + updatesIdx]; output.Set(updateValue); } } } } // namespace armnn