aboutsummaryrefslogtreecommitdiff
path: root/include/armnn/Descriptors.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'include/armnn/Descriptors.hpp')
-rw-r--r--include/armnn/Descriptors.hpp55
1 files changed, 54 insertions, 1 deletions
diff --git a/include/armnn/Descriptors.hpp b/include/armnn/Descriptors.hpp
index bf40b35ae9..7230bc2c1d 100644
--- a/include/armnn/Descriptors.hpp
+++ b/include/armnn/Descriptors.hpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2017-2023 Arm Ltd and Contributors. All rights reserved.
+// Copyright © 2017-2024 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once
@@ -1675,4 +1675,57 @@ struct BroadcastToDescriptor : BaseDescriptor
TensorShape m_BroadcastToShape;
};
+/// A ScatterNdDescriptor for the ScatterNdLayer.
+struct ScatterNdDescriptor : BaseDescriptor
+{
+ // default constructor
+ ScatterNdDescriptor()
+ : m_Function(ScatterNdFunction::Update)
+ , m_InputEnabled(true)
+ , m_Axis(0)
+ , m_AxisEnabled(false)
+ {}
+
+ // constructor for operators except for ScatterElement operator
+ ScatterNdDescriptor(ScatterNdFunction function,
+ bool inputEnabled)
+ : m_Function(function)
+ , m_InputEnabled(inputEnabled)
+ , m_Axis(0)
+ , m_AxisEnabled(false)
+
+ {}
+
+ // constructor for ScatterElement operator
+ ScatterNdDescriptor(ScatterNdFunction function,
+ bool inputEnabled,
+ int32_t axis)
+ : m_Function(function)
+ , m_InputEnabled(inputEnabled)
+ , m_Axis(axis)
+ , m_AxisEnabled(true)
+
+ {}
+
+ bool operator ==(const ScatterNdDescriptor &rhs) const
+ {
+ return ((m_Function == rhs.m_Function) &&
+ (m_InputEnabled == rhs.m_InputEnabled) &&
+ (m_AxisEnabled == rhs.m_AxisEnabled) &&
+ (m_Axis == rhs.m_Axis));
+ }
+
+ /// Specify if the function is update, add, sub, max or min.
+ ScatterNdFunction m_Function;
+
+ /// Flag to show if input tensor is accepted.
+ bool m_InputEnabled;
+
+ /// Extra attribute for ScatterElement, will be set to 0 by default, we do not support axis != 0
+ int32_t m_Axis;
+
+ /// Flag for ScatterElement, will be set to false by default, we do not support m_AxisEnable = true for now.
+ bool m_AxisEnabled;
+};
+
} // namespace armnn