aboutsummaryrefslogtreecommitdiff
path: root/src/backends/aclCommon
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/aclCommon')
-rw-r--r--src/backends/aclCommon/ArmComputeTensorUtils.cpp26
-rw-r--r--src/backends/aclCommon/ArmComputeTensorUtils.hpp6
-rw-r--r--src/backends/aclCommon/ArmComputeUtils.hpp26
3 files changed, 31 insertions, 27 deletions
diff --git a/src/backends/aclCommon/ArmComputeTensorUtils.cpp b/src/backends/aclCommon/ArmComputeTensorUtils.cpp
index c5b4fa157e..cfd2e0e110 100644
--- a/src/backends/aclCommon/ArmComputeTensorUtils.cpp
+++ b/src/backends/aclCommon/ArmComputeTensorUtils.cpp
@@ -459,5 +459,31 @@ unsigned int ComputeDepthwiseConv2dDepthMultiplier(armnn::DataLayout layout,
return depthMultiplier;
}
+arm_compute::ScatterInfo BuildArmComputeScatterInfo(const ScatterNdDescriptor& descriptor)
+{
+ arm_compute::ScatterFunction scatterFunction;
+ switch(descriptor.m_Function)
+ {
+ case ScatterNdFunction::Update:
+ scatterFunction = arm_compute::ScatterFunction::Update;
+ break;
+ case ScatterNdFunction::Add:
+ scatterFunction = arm_compute::ScatterFunction::Add;
+ break;
+ case ScatterNdFunction::Sub:
+ scatterFunction = arm_compute::ScatterFunction::Sub;
+ break;
+ case ScatterNdFunction::Max:
+ scatterFunction = arm_compute::ScatterFunction::Max;
+ break;
+ case ScatterNdFunction::Min:
+ scatterFunction = arm_compute::ScatterFunction::Min;
+ break;
+ default: throw InvalidArgumentException("Unknown ArmNN::ScatterNd Function: [" +
+ std::to_string(static_cast<int>(descriptor.m_Function)) + "]");
+ }
+
+ return arm_compute::ScatterInfo(scatterFunction, !descriptor.m_InputEnabled);
+}
} // namespace armcomputetensorutils
} // namespace armnn
diff --git a/src/backends/aclCommon/ArmComputeTensorUtils.hpp b/src/backends/aclCommon/ArmComputeTensorUtils.hpp
index d8a41fe41f..63c70c7092 100644
--- a/src/backends/aclCommon/ArmComputeTensorUtils.hpp
+++ b/src/backends/aclCommon/ArmComputeTensorUtils.hpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2017-2023 Arm Ltd and Contributors. All rights reserved.
+// Copyright © 2017-2024 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once
@@ -12,6 +12,7 @@
#include <arm_compute/core/ITensor.h>
#include <arm_compute/core/TensorInfo.h>
#include <arm_compute/core/Types.h>
+#include <arm_compute/function_info/ScatterInfo.h>
#include <Half.hpp>
@@ -108,6 +109,9 @@ unsigned int ComputeDepthwiseConv2dDepthMultiplier(armnn::DataLayout layout,
const arm_compute::TensorShape& weightsShape,
const arm_compute::TensorShape& inputShape);
+/// Utility function used to setup an arm_compute::ScatterInfo from ArmNN ScatterNd descriptor
+arm_compute::ScatterInfo BuildArmComputeScatterInfo(const ScatterNdDescriptor& descriptor);
+
/// Utility function used to setup an arm_compute::PadStrideInfo object from an ArmNN layer descriptor.
template <typename Descriptor>
arm_compute::PadStrideInfo BuildArmComputePadStrideInfo(const Descriptor& descriptor)
diff --git a/src/backends/aclCommon/ArmComputeUtils.hpp b/src/backends/aclCommon/ArmComputeUtils.hpp
index d7025aa5e2..fc77f810ee 100644
--- a/src/backends/aclCommon/ArmComputeUtils.hpp
+++ b/src/backends/aclCommon/ArmComputeUtils.hpp
@@ -242,32 +242,6 @@ inline T ComputeSoftmaxAclAxis(const SoftmaxDescriptor& softmaxDesc, const armnn
return aclAxis;
}
-inline std::set<unsigned int> ComputeSplitAxis(const armnn::SplitterDescriptor& desc, const TensorShape& input)
-{
- unsigned int numSplit = desc.GetNumViews();
- unsigned int numDimensions = desc.GetNumDimensions();
- std::set<unsigned int> splitAxis;
-
- if (desc.HasAxis())
- {
- splitAxis.insert(armnnUtils::GetUnsignedAxis(desc.GetNumDimensions(), desc.GetAxis()));
- }
- else
- {
- for (unsigned int i = 0; i < numSplit; ++i)
- {
- for (unsigned int dimIdx = 0; dimIdx < numDimensions; ++dimIdx)
- {
- if (desc.GetViewSizes(i)[dimIdx] != input[dimIdx])
- {
- splitAxis.insert(dimIdx);
- }
- }
- }
- }
- return splitAxis;
-}
-
/// Function to convert ArmNN axis (left to right) to ACL axis (right to left) ranging from [-rank, rank)
inline int ComputeAclAxis(const int& armnnAxis, const armnn::TensorInfo& tensor)
{