From b2d3ec5b1e938ef34facfdbcff83fc8e845d5f7c Mon Sep 17 00:00:00 2001 From: Teresa Charlin Date: Tue, 12 Apr 2022 22:07:09 +0100 Subject: IVGCVSW-6856 Add GATHERNd FrontEnd and Ref Implementation * Add front end * Add reference workload * Add unit tests * Add EndToEnd test Signed-off-by: Teresa Charlin Change-Id: I4cebd17b18476df86162e2dda3366c10e80bd2f8 --- include/armnn/BackendHelper.hpp | 5 +++++ include/armnn/INetwork.hpp | 5 +++++ include/armnn/Types.hpp | 1 + include/armnn/backends/WorkloadData.hpp | 5 +++++ 4 files changed, 16 insertions(+) (limited to 'include') diff --git a/include/armnn/BackendHelper.hpp b/include/armnn/BackendHelper.hpp index 0c625a6062..4772ca97cd 100644 --- a/include/armnn/BackendHelper.hpp +++ b/include/armnn/BackendHelper.hpp @@ -185,6 +185,11 @@ public: const GatherDescriptor& descriptor, Optional reasonIfUnsupported = EmptyOptional()); + bool IsGatherNdSupported(const TensorInfo& input0, + const TensorInfo& input1, + const TensorInfo& output, + Optional reasonIfUnsupported = EmptyOptional()); + bool IsInputSupported(const TensorInfo& input, Optional reasonIfUnsupported = EmptyOptional()); diff --git a/include/armnn/INetwork.hpp b/include/armnn/INetwork.hpp index eaec973899..7488fdc026 100644 --- a/include/armnn/INetwork.hpp +++ b/include/armnn/INetwork.hpp @@ -651,6 +651,11 @@ public: IConnectableLayer* AddGatherLayer(const GatherDescriptor& descriptor, const char* name = nullptr); + /// Add GatherNd layer to the network. + /// @param name - Optional name for the layer. + /// @return - Interface for configuring the layer. + IConnectableLayer* AddGatherNdLayer(const char* name = nullptr); + /// Adds a switch layer to the network. /// @param name - Optional name for the layer. /// @return - Interface for configuring the layer. diff --git a/include/armnn/Types.hpp b/include/armnn/Types.hpp index a804f55468..cc704a64ae 100644 --- a/include/armnn/Types.hpp +++ b/include/armnn/Types.hpp @@ -458,6 +458,7 @@ using InferenceTimingPair = std::pair; X(ChannelShuffle) \ X(Convolution3d) \ X(Pooling3d) \ + X(GatherNd)\ // New layers should be added at last to minimize instability. diff --git a/include/armnn/backends/WorkloadData.hpp b/include/armnn/backends/WorkloadData.hpp index 21141583c6..ed89f9638c 100644 --- a/include/armnn/backends/WorkloadData.hpp +++ b/include/armnn/backends/WorkloadData.hpp @@ -527,6 +527,11 @@ struct RsqrtQueueDescriptor : QueueDescriptor void Validate(const WorkloadInfo& workloadInfo) const; }; +struct GatherNdQueueDescriptor : QueueDescriptor +{ + void Validate(const WorkloadInfo& workloadInfo) const; +}; + struct GatherQueueDescriptor : QueueDescriptorWithParameters { void Validate(const WorkloadInfo& workloadInfo) const; -- cgit v1.2.1