aboutsummaryrefslogtreecommitdiff
path: root/src/backends/reference/workloads
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/reference/workloads')
-rw-r--r--src/backends/reference/workloads/CMakeLists.txt6
-rw-r--r--src/backends/reference/workloads/RefScatterNdWorkload.cpp100
-rw-r--r--src/backends/reference/workloads/RefScatterNdWorkload.hpp30
-rw-r--r--src/backends/reference/workloads/RefWorkloads.hpp3
-rw-r--r--src/backends/reference/workloads/ScatterNd.cpp336
-rw-r--r--src/backends/reference/workloads/ScatterNd.hpp34
6 files changed, 507 insertions, 2 deletions
diff --git a/src/backends/reference/workloads/CMakeLists.txt b/src/backends/reference/workloads/CMakeLists.txt
index 42f92aec1d..0f70cb0022 100644
--- a/src/backends/reference/workloads/CMakeLists.txt
+++ b/src/backends/reference/workloads/CMakeLists.txt
@@ -1,5 +1,5 @@
#
-# Copyright © 2017-2023 Arm Ltd and Contributors. All rights reserved.
+# Copyright © 2017-2024 Arm Ltd and Contributors. All rights reserved.
# SPDX-License-Identifier: MIT
#
@@ -85,6 +85,8 @@ list(APPEND armnnRefBackendWorkloads_sources
RefCastWorkload.hpp
RefChannelShuffleWorkload.cpp
RefChannelShuffleWorkload.hpp
+ RefScatterNdWorkload.cpp
+ RefScatterNdWorkload.hpp
RefShapeWorkload.hpp
RefComparisonWorkload.cpp
RefComparisonWorkload.hpp
@@ -195,6 +197,8 @@ list(APPEND armnnRefBackendWorkloads_sources
Resize.cpp
Resize.hpp
Rsqrt.hpp
+ ScatterNd.cpp
+ ScatterNd.hpp
Sin.hpp
Slice.cpp
Slice.hpp
diff --git a/src/backends/reference/workloads/RefScatterNdWorkload.cpp b/src/backends/reference/workloads/RefScatterNdWorkload.cpp
new file mode 100644
index 0000000000..4713add0e9
--- /dev/null
+++ b/src/backends/reference/workloads/RefScatterNdWorkload.cpp
@@ -0,0 +1,100 @@
+//
+// Copyright © 2024 Arm Ltd and Contributors. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include <fmt/format.h>
+#include "RefScatterNdWorkload.hpp"
+#include "RefWorkloadUtils.hpp"
+#include "ScatterNd.hpp"
+#include "Profiling.hpp"
+
+namespace armnn
+{
+
+ RefScatterNdWorkload::RefScatterNdWorkload(const ScatterNdQueueDescriptor& descriptor, const WorkloadInfo& info)
+ : RefBaseWorkload(descriptor, info)
+ {}
+
+ void RefScatterNdWorkload::Execute() const
+ {
+ Execute(m_Data.m_Inputs, m_Data.m_Outputs);
+ }
+
+ void RefScatterNdWorkload::ExecuteAsync(ExecutionData& executionData)
+ {
+ WorkingMemDescriptor* workingMemDescriptor = static_cast<WorkingMemDescriptor*>(executionData.m_Data);
+ Execute(workingMemDescriptor->m_Inputs, workingMemDescriptor->m_Outputs);
+ }
+
+ void RefScatterNdWorkload::Execute(std::vector<ITensorHandle*> inputs, std::vector<ITensorHandle*> outputs) const
+ {
+ ARMNN_SCOPED_PROFILING_EVENT_REF_NAME_GUID("RefScatterNdWorkload_Execute");
+
+ if (m_Data.m_Parameters.m_InputEnabled)
+ {
+ // Getting TensorInfos for three inputs slots
+ const TensorInfo& inputInfo = GetTensorInfo(inputs[0]);
+ const TensorInfo& indicesInfo = GetTensorInfo(inputs[1]);
+ const TensorInfo& updatesInfo = GetTensorInfo(inputs[2]);
+
+ // Getting Decoder for input
+ std::unique_ptr<Decoder<float>> inputDecoder = MakeDecoder<float>(GetTensorInfo(inputs[0]),
+ inputs[0]->Map());
+
+ // Getting Decoder for indices
+ std::unique_ptr<Decoder<int>> indicesDecoder = MakeDecoder<int>(GetTensorInfo(inputs[1]),
+ inputs[1]->Map());
+
+ // Getting Decoder for updates
+ std::unique_ptr<Decoder<float>> updatesDecoder = MakeDecoder<float>(GetTensorInfo(inputs[2]),
+ inputs[2]->Map());
+
+ // Getting Encoder for output
+ std::unique_ptr<Encoder<float>> outputEncoder = MakeEncoder<float>(GetTensorInfo(outputs[0]),
+ outputs[0]->Map());
+
+ ScatterNd(inputInfo,
+ indicesInfo,
+ updatesInfo,
+ *inputDecoder,
+ *indicesDecoder,
+ *updatesDecoder,
+ *outputEncoder,
+ m_Data.m_Parameters);
+ }
+ else
+ {
+ // Getting TensorInfos for three inputs slots
+ const TensorInfo& shapeInfo = GetTensorInfo(inputs[0]);
+ const TensorInfo& indicesInfo = GetTensorInfo(inputs[1]);
+ const TensorInfo& updatesInfo = GetTensorInfo(inputs[2]);
+
+ // Getting Decoder for shape
+ std::unique_ptr<Decoder<int>> shapeDecoder = MakeDecoder<int>(GetTensorInfo(inputs[0]),
+ inputs[0]->Map());
+
+ // Getting Decoder for indices
+ std::unique_ptr<Decoder<int>> indicesDecoder = MakeDecoder<int>(GetTensorInfo(inputs[1]),
+ inputs[1]->Map());
+
+ // Getting Decoder for updates
+ std::unique_ptr<Decoder<float>> updatesDecoder = MakeDecoder<float>(GetTensorInfo(inputs[2]),
+ inputs[2]->Map());
+
+ // Getting Encoder for output
+ std::unique_ptr<Encoder<float>> outputEncoder = MakeEncoder<float>(GetTensorInfo(outputs[0]),
+ outputs[0]->Map());
+
+ ScatterNd(indicesInfo,
+ updatesInfo,
+ shapeInfo,
+ *indicesDecoder,
+ *updatesDecoder,
+ *shapeDecoder,
+ *outputEncoder,
+ m_Data.m_Parameters);
+ }
+ }
+
+} // namespace armnn \ No newline at end of file
diff --git a/src/backends/reference/workloads/RefScatterNdWorkload.hpp b/src/backends/reference/workloads/RefScatterNdWorkload.hpp
new file mode 100644
index 0000000000..c9cf5a3af3
--- /dev/null
+++ b/src/backends/reference/workloads/RefScatterNdWorkload.hpp
@@ -0,0 +1,30 @@
+//
+// Copyright © 2024 Arm Ltd and Contributors. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+#include "RefBaseWorkload.hpp"
+#include <armnn/backends/WorkloadData.hpp>
+
+#include "ScatterNd.hpp"
+
+namespace armnn
+{
+
+ class RefScatterNdWorkload : public RefBaseWorkload<ScatterNdQueueDescriptor>
+ {
+ public:
+ explicit RefScatterNdWorkload(const ScatterNdQueueDescriptor& descriptor,
+ const WorkloadInfo& info);
+
+ void Execute() const override;
+ void ExecuteAsync(ExecutionData& executionData) override;
+
+ private:
+ void Execute(std::vector<ITensorHandle*> inputs, std::vector<ITensorHandle*> outputs) const;
+
+ };
+
+} // namespace armnn \ No newline at end of file
diff --git a/src/backends/reference/workloads/RefWorkloads.hpp b/src/backends/reference/workloads/RefWorkloads.hpp
index 98aa27b8a9..92b178c3d5 100644
--- a/src/backends/reference/workloads/RefWorkloads.hpp
+++ b/src/backends/reference/workloads/RefWorkloads.hpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2017-2023 Arm Ltd and Contributors. All rights reserved.
+// Copyright © 2017-2024 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
@@ -55,6 +55,7 @@
#include "RefReshapeWorkload.hpp"
#include "RefResizeWorkload.hpp"
#include "RefReverseV2Workload.hpp"
+#include "RefScatterNdWorkload.hpp"
#include "RefShapeWorkload.hpp"
#include "RefSliceWorkload.hpp"
#include "RefSplitterWorkload.hpp"
diff --git a/src/backends/reference/workloads/ScatterNd.cpp b/src/backends/reference/workloads/ScatterNd.cpp
new file mode 100644
index 0000000000..8eb53b00a8
--- /dev/null
+++ b/src/backends/reference/workloads/ScatterNd.cpp
@@ -0,0 +1,336 @@
+//
+// Copyright © 2024 Arm Ltd and Contributors. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include "ScatterNd.hpp"
+#include "Encoders.hpp"
+#include <armnn/backends/WorkloadData.hpp>
+#include <armnn/Logging.hpp>
+
+#include <fmt/format.h>
+
+#include <numeric>
+
+namespace armnn
+{
+
+float ScatterOperation(ScatterNdFunction operation,
+ float input,
+ float update)
+{
+ switch (operation)
+ {
+ case ScatterNdFunction::Update:
+ return update;
+ case ScatterNdFunction::Add:
+ return input + update;
+ case ScatterNdFunction::Sub:
+ return input - update;
+ case ScatterNdFunction::Max:
+ return std::max(input, update);
+ case ScatterNdFunction::Min:
+ return std::min(input, update);
+ case ScatterNdFunction::Mul:
+ return input * update;
+ default:
+ throw InvalidArgumentException("ScatterNd: cannot execute this operation.");
+ }
+}
+
+void ScatterNd(const TensorInfo& inputInfo,
+ const TensorInfo& indicesInfo,
+ const TensorInfo& updatesInfo,
+ Decoder<float>& input,
+ Decoder<int>& indices,
+ Decoder<float>& updates,
+ Encoder<float>& output,
+ ScatterNdDescriptor descriptor)
+{
+ // Axis Unsupported
+ if (descriptor.m_AxisEnabled)
+ {
+ throw InvalidArgumentException("ScatterNd: axis param not supported.");
+ }
+
+ // Get the shape for indices, updates, and input
+ TensorShape indicesShape = indicesInfo.GetShape();
+ TensorShape updatesShape = updatesInfo.GetShape();
+ TensorShape inputShape = inputInfo.GetShape();
+
+ // Get the dimensions for indices and updates
+ unsigned int dimension = inputInfo.GetNumDimensions();
+ unsigned int indicesDim = indicesInfo.GetNumDimensions();
+ unsigned int updatesDim = updatesInfo.GetNumDimensions();
+
+ // Calculate the outter and inner dimensions
+ unsigned int outterDim = indicesShape[indicesDim - 1];
+ unsigned int innerDim = dimension - outterDim;
+
+ // Calculate the number of elements in each dimension
+ unsigned int numElementsCount = 1;
+ std::vector<unsigned int> elementInDim(dimension);
+ for (unsigned int dimIndex = dimension; dimIndex > 0; --dimIndex)
+ {
+ elementInDim[dimIndex - 1] = numElementsCount;
+ numElementsCount *= inputShape[dimIndex - 1];
+ }
+
+ // Number of updates per index
+ unsigned int numUpdatesPerIndex = elementInDim[dimension - innerDim - 1];
+
+ // Number of indices to update
+ unsigned int numIndices = indicesShape[0];
+
+ // Check Input Requirements
+ // Requirement 1: Indices and Updates must have rank at least 1
+ if (indicesDim < 1 || updatesDim < 1)
+ {
+ throw InvalidArgumentException("ScatterNd: indices and updates must have rank >= 1.");
+ }
+
+ // Requirement 2: Input, Indices and Updates must have values
+ if (inputInfo.GetNumElements() == 0 ||
+ indicesInfo.GetNumElements() == 0 ||
+ updatesInfo.GetNumElements() == 0)
+ {
+ throw InvalidArgumentException("ScatterNd: input, indices and updates tensor must have values.");
+ }
+
+ // Requirement 3: Indices and Updates must match in shape
+ // The updates dimension should equals to 1 + inner dimension
+ if (updatesDim != 1 + innerDim)
+ {
+ throw InvalidArgumentException("ScatterNd: updates dimension should equal to 1 + inner dimension.");
+ }
+ // The inner dimension of updates has to match with shape of input
+ for (unsigned int dimBackIndex = 0; dimBackIndex < innerDim; ++dimBackIndex)
+ {
+ if (updatesShape[updatesDim - dimBackIndex - 1] != inputShape[dimension - dimBackIndex - 1])
+ {
+ throw InvalidArgumentException(
+ fmt::format("ScatterNd: input and updates shape not match on dimension {}",
+ dimension - dimBackIndex));
+ }
+ }
+
+ // Requirement 4: Check duplicate indices and out of bound indices
+ std::set<int> indicesSet;
+ std::vector<int> flattenIndices(numIndices);
+ for (unsigned int indicesIdx = 0; indicesIdx < numIndices; ++indicesIdx)
+ {
+ // Get the index
+ int flattenIndex = 0;
+
+ for (unsigned int outterIdx = 0; outterIdx < outterDim; ++outterIdx) {
+
+ int outterIndexValue = indices.Get();
+
+ // Check bounds
+ if (outterIndexValue < 0 || outterIndexValue >= int(inputShape[outterIdx]))
+ {
+ throw InvalidArgumentException(
+ fmt::format("ScatterNd: indices {} out of bound [0, {})",
+ outterIndexValue, inputShape[outterIdx]));
+ }
+
+ flattenIndex += int(elementInDim[outterIdx]) * outterIndexValue;
+ ++indices;
+ }
+
+ // Check duplicates when executing ScatterNd::Update
+ if (descriptor.m_Function == ScatterNdFunction::Update &&
+ indicesSet.find(flattenIndex) != indicesSet.end())
+ {
+ throw InvalidArgumentException(
+ fmt::format("ScatterNd: duplicate indices occurs {}", flattenIndex));
+ }
+
+ flattenIndices[indicesIdx] = flattenIndex;
+ indicesSet.insert(flattenIndex);
+ }
+
+ // Set the input data to output
+ for (unsigned int idx = 0; idx < inputInfo.GetNumElements(); ++idx)
+ {
+ float inputValue = input.Get();
+ ++input;
+ output.Set(inputValue);
+ ++output;
+ }
+
+ // Iterate through all indices to scatter updates
+ for (unsigned int indicesIdx = 0; indicesIdx < numIndices; ++indicesIdx)
+ {
+ // Get the index and calculate the flatten index
+ int flattenIndex = flattenIndices[indicesIdx];
+
+ // FlattenIndex is the place that we are going to update the elements
+ unsigned int updatesStartIdx = indicesIdx * numUpdatesPerIndex;
+ for (unsigned int updatesIdx = 0; updatesIdx < numUpdatesPerIndex; ++updatesIdx)
+ {
+ updates[updatesStartIdx + updatesIdx];
+ input[static_cast<unsigned int>(flattenIndex) + updatesIdx];
+ float updateValue = ScatterOperation(descriptor.m_Function, input.Get(), updates.Get());
+ output[static_cast<unsigned int>(flattenIndex) + updatesIdx];
+ output.Set(updateValue);
+ }
+ }
+}
+
+void ScatterNd(const TensorInfo& indicesInfo,
+ const TensorInfo& updatesInfo,
+ const TensorInfo& shapeInfo,
+ Decoder<int>& indices,
+ Decoder<float>& updates,
+ Decoder<int>& shape,
+ Encoder<float>& output,
+ ScatterNdDescriptor descriptor)
+{
+ // Axis Unsupported
+ if (descriptor.m_AxisEnabled)
+ {
+ throw InvalidArgumentException("ScatterNd: axis param not supported.");
+ }
+
+ // Get the shape for indices, updates, and input
+ TensorShape indicesShape = indicesInfo.GetShape();
+ TensorShape updatesShape = updatesInfo.GetShape();
+
+ // Get the shape values
+ std::vector<float> shapeValues = shape.DecodeTensor(shapeInfo.GetShape());
+ // Check the shape
+ if (shapeInfo.GetNumElements() == 0)
+ {
+ throw InvalidArgumentException("ScatterNd: shape must have values.");
+ }
+ for (auto shapeValue : shapeValues)
+ {
+ if (shapeValue <= 0)
+ {
+ throw InvalidArgumentException("ScatterNd: shape values must >= 0.");
+ }
+ }
+ // Get the input shape
+ std::vector<unsigned int> inputShape (shapeValues.begin(), shapeValues.end());
+ unsigned int inputElementsNum = static_cast<unsigned int>(
+ std::accumulate(inputShape.begin(), inputShape.end(), 1, std::multiplies<unsigned int>()));
+
+ // Get the dimensions for indices and updates
+ unsigned int dimension = shapeInfo.GetNumElements();
+ unsigned int indicesDim = indicesInfo.GetNumDimensions();
+ unsigned int updatesDim = updatesInfo.GetNumDimensions();
+
+ // Calculate the outter and inner dimensions
+ unsigned int outterDim = indicesShape[indicesDim - 1];
+ unsigned int innerDim = dimension - outterDim;
+
+ // Calculate the number of elements in each dimension
+ unsigned int numElementsCount = 1;
+ std::vector<unsigned int> elementInDim(dimension);
+ for (unsigned int dimIndex = dimension; dimIndex > 0; --dimIndex)
+ {
+ elementInDim[dimIndex - 1] = numElementsCount;
+ numElementsCount *= inputShape[dimIndex - 1];
+ }
+
+ // Number of updates per index
+ unsigned int numUpdatesPerIndex = elementInDim[dimension - innerDim - 1];
+
+ // Number of indices to update
+ unsigned int numIndices = indicesShape[0];
+
+ // Check Input Requirements
+ // Requirement 1: Indices and Updates must have rank at least 1
+ if (indicesDim < 1 || updatesDim < 1)
+ {
+ throw InvalidArgumentException("ScatterNd: indices and updates must have rank >= 1.");
+ }
+
+ // Requirement 2: shape, Indices and Updates must have values
+ if (indicesInfo.GetNumElements() == 0 ||
+ updatesInfo.GetNumElements() == 0)
+ {
+ throw InvalidArgumentException("ScatterNd: indices and updates tensor must have values.");
+ }
+
+ // Requirement 3: Indices and Updates must match in shape
+ // The updates dimension should equals to 1 + inner dimension
+ if (updatesDim != 1 + innerDim)
+ {
+ throw InvalidArgumentException("ScatterNd: updates dimension should equal to 1 + inner dimension.");
+ }
+ // The inner dimension of updates has to match with shape of input
+ for (unsigned int dimBackIndex = 0; dimBackIndex < innerDim; ++dimBackIndex)
+ {
+ if (updatesShape[updatesDim - dimBackIndex - 1] != inputShape[dimension - dimBackIndex - 1])
+ {
+ throw InvalidArgumentException(
+ fmt::format("ScatterNd: input and updates shape not match on dimension {}",
+ dimension - dimBackIndex));
+ }
+ }
+
+ // Requirement 4: Check duplicate indices and out of bound indices
+ std::set<int> indicesSet;
+ std::vector<int> flattenIndices(numIndices);
+ for (unsigned int indicesIdx = 0; indicesIdx < numIndices; ++indicesIdx)
+ {
+ // Get the index
+ int flattenIndex = 0;
+
+ for (unsigned int outterIdx = 0; outterIdx < outterDim; ++outterIdx) {
+
+ int outterIndexValue = indices.Get();
+
+ // Check bounds
+ if (outterIndexValue < 0 || outterIndexValue >= int(inputShape[outterIdx]))
+ {
+ throw InvalidArgumentException(
+ fmt::format("ScatterNd: indices {} out of bound [0, {})",
+ outterIndexValue, inputShape[outterIdx]));
+ }
+
+ flattenIndex += int(elementInDim[outterIdx]) * outterIndexValue;
+ ++indices;
+ }
+
+ // Check duplicates when executing ScatterNd::Update
+ if (descriptor.m_Function == ScatterNdFunction::Update &&
+ indicesSet.find(flattenIndex) != indicesSet.end())
+ {
+ throw InvalidArgumentException(
+ fmt::format("ScatterNd: duplicate indices {} occurs when executing ScatterNd::Update.",
+ flattenIndex));
+ }
+
+ flattenIndices[indicesIdx] = flattenIndex;
+ indicesSet.insert(flattenIndex);
+ }
+
+ // Set zeros to output
+ for (unsigned int idx = 0; idx < inputElementsNum; ++idx)
+ {
+ output.Set(0.0f);
+ ++output;
+ }
+
+ // Iterate through all indices to scatter updates
+ for (unsigned int indicesIdx = 0; indicesIdx < numIndices; ++indicesIdx)
+ {
+ // Get the index and calculate the flatten index
+ int flattenIndex = flattenIndices[indicesIdx];
+
+ // FlattenIndex is the place that we are going to update the elements
+ unsigned int updatesStartIdx = indicesIdx * numUpdatesPerIndex;
+ for (unsigned int updatesIdx = 0; updatesIdx < numUpdatesPerIndex; ++updatesIdx)
+ {
+ updates[updatesStartIdx + updatesIdx];
+ float updateValue = ScatterOperation(descriptor.m_Function, 0.0f, updates.Get());
+ output[static_cast<unsigned int>(flattenIndex) + updatesIdx];
+ output.Set(updateValue);
+ }
+ }
+}
+
+} // namespace armnn \ No newline at end of file
diff --git a/src/backends/reference/workloads/ScatterNd.hpp b/src/backends/reference/workloads/ScatterNd.hpp
new file mode 100644
index 0000000000..e40d3640a7
--- /dev/null
+++ b/src/backends/reference/workloads/ScatterNd.hpp
@@ -0,0 +1,34 @@
+//
+// Copyright © 2024 Arm Ltd and Contributors. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+#include <armnn/Tensor.hpp>
+#include "Encoders.hpp"
+#include "Decoders.hpp"
+#include "armnn/Descriptors.hpp"
+
+namespace armnn
+{
+// ScatterNd with input tensor
+void ScatterNd(const TensorInfo& inputInfo,
+ const TensorInfo& indicesInfo,
+ const TensorInfo& updatesInfo,
+ Decoder<float>& input,
+ Decoder<int>& indices,
+ Decoder<float>& updates,
+ Encoder<float>& output,
+ ScatterNdDescriptor descriptor);
+
+// ScatterNd without input tensor, only shape provided
+void ScatterNd(const TensorInfo& indicesInfo,
+ const TensorInfo& updatesInfo,
+ const TensorInfo& shapeInfo,
+ Decoder<int>& indices,
+ Decoder<float>& updates,
+ Decoder<int>& shape,
+ Encoder<float>& output,
+ ScatterNdDescriptor descriptor);
+} // namespace armnn \ No newline at end of file