aboutsummaryrefslogtreecommitdiff
path: root/include/armnn
diff options
context:
space:
mode:
Diffstat (limited to 'include/armnn')
-rw-r--r--include/armnn/BackendHelper.hpp5
-rw-r--r--include/armnn/INetwork.hpp5
-rw-r--r--include/armnn/Types.hpp1
-rw-r--r--include/armnn/backends/WorkloadData.hpp5
4 files changed, 16 insertions, 0 deletions
diff --git a/include/armnn/BackendHelper.hpp b/include/armnn/BackendHelper.hpp
index 0c625a6062..4772ca97cd 100644
--- a/include/armnn/BackendHelper.hpp
+++ b/include/armnn/BackendHelper.hpp
@@ -185,6 +185,11 @@ public:
const GatherDescriptor& descriptor,
Optional<std::string&> reasonIfUnsupported = EmptyOptional());
+ bool IsGatherNdSupported(const TensorInfo& input0,
+ const TensorInfo& input1,
+ const TensorInfo& output,
+ Optional<std::string&> reasonIfUnsupported = EmptyOptional());
+
bool IsInputSupported(const TensorInfo& input,
Optional<std::string&> reasonIfUnsupported = EmptyOptional());
diff --git a/include/armnn/INetwork.hpp b/include/armnn/INetwork.hpp
index eaec973899..7488fdc026 100644
--- a/include/armnn/INetwork.hpp
+++ b/include/armnn/INetwork.hpp
@@ -651,6 +651,11 @@ public:
IConnectableLayer* AddGatherLayer(const GatherDescriptor& descriptor,
const char* name = nullptr);
+ /// Add GatherNd layer to the network.
+ /// @param name - Optional name for the layer.
+ /// @return - Interface for configuring the layer.
+ IConnectableLayer* AddGatherNdLayer(const char* name = nullptr);
+
/// Adds a switch layer to the network.
/// @param name - Optional name for the layer.
/// @return - Interface for configuring the layer.
diff --git a/include/armnn/Types.hpp b/include/armnn/Types.hpp
index a804f55468..cc704a64ae 100644
--- a/include/armnn/Types.hpp
+++ b/include/armnn/Types.hpp
@@ -458,6 +458,7 @@ using InferenceTimingPair = std::pair<HighResolutionClock, HighResolutionClock>;
X(ChannelShuffle) \
X(Convolution3d) \
X(Pooling3d) \
+ X(GatherNd)\
// New layers should be added at last to minimize instability.
diff --git a/include/armnn/backends/WorkloadData.hpp b/include/armnn/backends/WorkloadData.hpp
index 21141583c6..ed89f9638c 100644
--- a/include/armnn/backends/WorkloadData.hpp
+++ b/include/armnn/backends/WorkloadData.hpp
@@ -527,6 +527,11 @@ struct RsqrtQueueDescriptor : QueueDescriptor
void Validate(const WorkloadInfo& workloadInfo) const;
};
+struct GatherNdQueueDescriptor : QueueDescriptor
+{
+ void Validate(const WorkloadInfo& workloadInfo) const;
+};
+
struct GatherQueueDescriptor : QueueDescriptorWithParameters<GatherDescriptor>
{
void Validate(const WorkloadInfo& workloadInfo) const;