aboutsummaryrefslogtreecommitdiff
path: root/include/armnn/Descriptors.hpp
diff options
context:
space:
mode:
authorTianle Cheng <tianle.cheng@arm.com>2024-02-23 17:56:54 +0000
committerKevin May <kevin.may@arm.com>2024-02-28 16:12:34 +0000
commit282881877522d3e94752dfc0839de9bfa0aa5a81 (patch)
tree9cd11c96eb4c179e76f2e586d5a9d9b416dd85a0 /include/armnn/Descriptors.hpp
parent2883a86c5a167aea3c736529bff5921ab6cbc99c (diff)
downloadarmnn-282881877522d3e94752dfc0839de9bfa0aa5a81.tar.gz
IVGCVSW-8229 & IVGCVSW-8237 ScatterNd: Front end and reference implementation
(scatter_nd, scatter_nd_add, and scatter_nd_update, scatter_nd_sub, scatter_nd_min, scatter_nd_max, scatter_nd_mul) * Front end support for ScatterNd added. * Reference implementation for ScatterNd added. * Unit tests added. Signed-off-by: Tianle Cheng <tianle.cheng@arm.com> Change-Id: I30da9056d9b03ca9b5fb8d09987341128badbcf4
Diffstat (limited to 'include/armnn/Descriptors.hpp')
-rw-r--r--include/armnn/Descriptors.hpp55
1 files changed, 54 insertions, 1 deletions
diff --git a/include/armnn/Descriptors.hpp b/include/armnn/Descriptors.hpp
index bf40b35ae9..7230bc2c1d 100644
--- a/include/armnn/Descriptors.hpp
+++ b/include/armnn/Descriptors.hpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2017-2023 Arm Ltd and Contributors. All rights reserved.
+// Copyright © 2017-2024 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once
@@ -1675,4 +1675,57 @@ struct BroadcastToDescriptor : BaseDescriptor
TensorShape m_BroadcastToShape;
};
+/// A ScatterNdDescriptor for the ScatterNdLayer.
+struct ScatterNdDescriptor : BaseDescriptor
+{
+ // default constructor
+ ScatterNdDescriptor()
+ : m_Function(ScatterNdFunction::Update)
+ , m_InputEnabled(true)
+ , m_Axis(0)
+ , m_AxisEnabled(false)
+ {}
+
+ // constructor for operators except for ScatterElement operator
+ ScatterNdDescriptor(ScatterNdFunction function,
+ bool inputEnabled)
+ : m_Function(function)
+ , m_InputEnabled(inputEnabled)
+ , m_Axis(0)
+ , m_AxisEnabled(false)
+
+ {}
+
+ // constructor for ScatterElement operator
+ ScatterNdDescriptor(ScatterNdFunction function,
+ bool inputEnabled,
+ int32_t axis)
+ : m_Function(function)
+ , m_InputEnabled(inputEnabled)
+ , m_Axis(axis)
+ , m_AxisEnabled(true)
+
+ {}
+
+ bool operator ==(const ScatterNdDescriptor &rhs) const
+ {
+ return ((m_Function == rhs.m_Function) &&
+ (m_InputEnabled == rhs.m_InputEnabled) &&
+ (m_AxisEnabled == rhs.m_AxisEnabled) &&
+ (m_Axis == rhs.m_Axis));
+ }
+
+ /// Specify if the function is update, add, sub, max or min.
+ ScatterNdFunction m_Function;
+
+ /// Flag to show if input tensor is accepted.
+ bool m_InputEnabled;
+
+ /// Extra attribute for ScatterElement, will be set to 0 by default, we do not support axis != 0
+ int32_t m_Axis;
+
+ /// Flag for ScatterElement, will be set to false by default, we do not support m_AxisEnable = true for now.
+ bool m_AxisEnabled;
+};
+
} // namespace armnn