aboutsummaryrefslogtreecommitdiff
path: root/include
diff options
context:
space:
mode:
authorTianle Cheng <tianle.cheng@arm.com>2024-02-23 17:56:54 +0000
committerKevin May <kevin.may@arm.com>2024-02-28 16:12:34 +0000
commit282881877522d3e94752dfc0839de9bfa0aa5a81 (patch)
tree9cd11c96eb4c179e76f2e586d5a9d9b416dd85a0 /include
parent2883a86c5a167aea3c736529bff5921ab6cbc99c (diff)
downloadarmnn-282881877522d3e94752dfc0839de9bfa0aa5a81.tar.gz
IVGCVSW-8229 & IVGCVSW-8237 ScatterNd: Front end and reference implementation
(scatter_nd, scatter_nd_add, and scatter_nd_update, scatter_nd_sub, scatter_nd_min, scatter_nd_max, scatter_nd_mul) * Front end support for ScatterNd added. * Reference implementation for ScatterNd added. * Unit tests added. Signed-off-by: Tianle Cheng <tianle.cheng@arm.com> Change-Id: I30da9056d9b03ca9b5fb8d09987341128badbcf4
Diffstat (limited to 'include')
-rw-r--r--include/armnn/BackendHelper.hpp9
-rw-r--r--include/armnn/Descriptors.hpp55
-rw-r--r--include/armnn/DescriptorsFwd.hpp3
-rw-r--r--include/armnn/INetwork.hpp9
-rw-r--r--include/armnn/Types.hpp18
-rw-r--r--include/armnn/backends/WorkloadData.hpp7
6 files changed, 92 insertions, 9 deletions
diff --git a/include/armnn/BackendHelper.hpp b/include/armnn/BackendHelper.hpp
index b61f010b0f..a6b81eaa01 100644
--- a/include/armnn/BackendHelper.hpp
+++ b/include/armnn/BackendHelper.hpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2017-2019,2021-2023 Arm Ltd and Contributors. All rights reserved.
+// Copyright © 2017-2019,2021-2024 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
@@ -380,6 +380,13 @@ public:
const TensorInfo& output,
Optional<std::string&> reasonIfUnsupported = EmptyOptional());
+ bool IsScatterNdSupported(const TensorInfo& input,
+ const TensorInfo& indices,
+ const TensorInfo& updates,
+ const TensorInfo& output,
+ const ScatterNdDescriptor& descriptor,
+ Optional<std::string&> reasonIfUnsupported = EmptyOptional());
+
bool IsSliceSupported(const TensorInfo& input,
const TensorInfo& output,
const SliceDescriptor& descriptor,
diff --git a/include/armnn/Descriptors.hpp b/include/armnn/Descriptors.hpp
index bf40b35ae9..7230bc2c1d 100644
--- a/include/armnn/Descriptors.hpp
+++ b/include/armnn/Descriptors.hpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2017-2023 Arm Ltd and Contributors. All rights reserved.
+// Copyright © 2017-2024 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once
@@ -1675,4 +1675,57 @@ struct BroadcastToDescriptor : BaseDescriptor
TensorShape m_BroadcastToShape;
};
+/// A ScatterNdDescriptor for the ScatterNdLayer.
+struct ScatterNdDescriptor : BaseDescriptor
+{
+ // default constructor
+ ScatterNdDescriptor()
+ : m_Function(ScatterNdFunction::Update)
+ , m_InputEnabled(true)
+ , m_Axis(0)
+ , m_AxisEnabled(false)
+ {}
+
+ // constructor for operators except for ScatterElement operator
+ ScatterNdDescriptor(ScatterNdFunction function,
+ bool inputEnabled)
+ : m_Function(function)
+ , m_InputEnabled(inputEnabled)
+ , m_Axis(0)
+ , m_AxisEnabled(false)
+
+ {}
+
+ // constructor for ScatterElement operator
+ ScatterNdDescriptor(ScatterNdFunction function,
+ bool inputEnabled,
+ int32_t axis)
+ : m_Function(function)
+ , m_InputEnabled(inputEnabled)
+ , m_Axis(axis)
+ , m_AxisEnabled(true)
+
+ {}
+
+ bool operator ==(const ScatterNdDescriptor &rhs) const
+ {
+ return ((m_Function == rhs.m_Function) &&
+ (m_InputEnabled == rhs.m_InputEnabled) &&
+ (m_AxisEnabled == rhs.m_AxisEnabled) &&
+ (m_Axis == rhs.m_Axis));
+ }
+
+ /// Specify if the function is update, add, sub, max or min.
+ ScatterNdFunction m_Function;
+
+ /// Flag to show if input tensor is accepted.
+ bool m_InputEnabled;
+
+ /// Extra attribute for ScatterElement, will be set to 0 by default, we do not support axis != 0
+ int32_t m_Axis;
+
+ /// Flag for ScatterElement, will be set to false by default, we do not support m_AxisEnable = true for now.
+ bool m_AxisEnabled;
+};
+
} // namespace armnn
diff --git a/include/armnn/DescriptorsFwd.hpp b/include/armnn/DescriptorsFwd.hpp
index 4b0b70c2d3..3518a41c42 100644
--- a/include/armnn/DescriptorsFwd.hpp
+++ b/include/armnn/DescriptorsFwd.hpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2017-2023 Arm Ltd and Contributors. All rights reserved.
+// Copyright © 2017-2024 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
@@ -44,6 +44,7 @@ struct QLstmDescriptor;
struct ReshapeDescriptor;
struct ResizeDescriptor;
struct ReduceDescriptor;
+struct ScatterNdDescriptor;
struct SliceDescriptor;
struct SoftmaxDescriptor;
struct SpaceToBatchNdDescriptor;
diff --git a/include/armnn/INetwork.hpp b/include/armnn/INetwork.hpp
index 64fdab6bd0..84f3e0cb64 100644
--- a/include/armnn/INetwork.hpp
+++ b/include/armnn/INetwork.hpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2017-2023 Arm Ltd and Contributors. All rights reserved.
+// Copyright © 2017-2024 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once
@@ -864,6 +864,13 @@ public:
IConnectableLayer* AddBroadcastToLayer(const BroadcastToDescriptor& descriptor,
const char* name = nullptr);
+ /// Add a ScatterNd layer to the network
+ /// @param descriptor - Parameters for the ScatterNd operation
+ /// @param name - Optional name for the layer
+ /// @return - Interface for configuring the layer
+ IConnectableLayer* AddScatterNdLayer(const ScatterNdDescriptor& descriptor,
+ const char* name = nullptr);
+
void ExecuteStrategy(IStrategy& strategy) const;
protected:
diff --git a/include/armnn/Types.hpp b/include/armnn/Types.hpp
index d87e7f7147..bbe1ecccbd 100644
--- a/include/armnn/Types.hpp
+++ b/include/armnn/Types.hpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2018-2023 Arm Ltd and Contributors. All rights reserved.
+// Copyright © 2018-2024 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once
@@ -482,8 +482,8 @@ using InferenceTimingPair = std::pair<HighResolutionClock, HighResolutionClock>;
X(ReverseV2) \
X(Tile) \
X(Fused) \
- X(BroadcastTo) \
-
+ X(BroadcastTo) \
+ X(ScatterNd) \
// New layers should be added at last position to minimize instability.
/// When adding a new layer, adapt also the LastLayer enum value in the
@@ -494,7 +494,17 @@ enum class LayerType
LIST_OF_LAYER_TYPE
#undef X
FirstLayer = Activation,
- LastLayer = BroadcastTo
+ LastLayer = ScatterNd
+};
+
+enum class ScatterNdFunction
+{
+ Update = 0,
+ Add = 1,
+ Sub = 2,
+ Max = 3,
+ Min = 4,
+ Mul = 5
};
const char* GetLayerTypeAsCString(LayerType type);
diff --git a/include/armnn/backends/WorkloadData.hpp b/include/armnn/backends/WorkloadData.hpp
index a90a1abd65..a93d986e4d 100644
--- a/include/armnn/backends/WorkloadData.hpp
+++ b/include/armnn/backends/WorkloadData.hpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2021-2023 Arm Ltd and Contributors. All rights reserved.
+// Copyright © 2021-2024 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once
@@ -770,4 +770,9 @@ struct BroadcastToQueueDescriptor : QueueDescriptorWithParameters<BroadcastToDes
void Validate(const WorkloadInfo& workloadInfo) const;
};
+struct ScatterNdQueueDescriptor : QueueDescriptorWithParameters<ScatterNdDescriptor>
+{
+ void Validate(const WorkloadInfo& workloadInfo) const;
+};
+
} // namespace armnn