aboutsummaryrefslogtreecommitdiff
path: root/include
diff options
context:
space:
mode:
authorTeresa Charlin <teresa.charlinreyes@arm.com>2022-04-12 22:07:09 +0100
committerTeresa Charlin <teresa.charlinreyes@arm.com>2022-05-03 21:24:52 +0100
commitb2d3ec5b1e938ef34facfdbcff83fc8e845d5f7c (patch)
tree74ee2c47e76fddff249a9f25db01960a52eb2360 /include
parent04cd60384f5fc8455bb7cf64416daa7b001754d1 (diff)
downloadarmnn-b2d3ec5b1e938ef34facfdbcff83fc8e845d5f7c.tar.gz
IVGCVSW-6856 Add GATHERNd FrontEnd and Ref Implementation
* Add front end * Add reference workload * Add unit tests * Add EndToEnd test Signed-off-by: Teresa Charlin <teresa.charlinreyes@arm.com> Change-Id: I4cebd17b18476df86162e2dda3366c10e80bd2f8
Diffstat (limited to 'include')
-rw-r--r--include/armnn/BackendHelper.hpp5
-rw-r--r--include/armnn/INetwork.hpp5
-rw-r--r--include/armnn/Types.hpp1
-rw-r--r--include/armnn/backends/WorkloadData.hpp5
4 files changed, 16 insertions, 0 deletions
diff --git a/include/armnn/BackendHelper.hpp b/include/armnn/BackendHelper.hpp
index 0c625a6062..4772ca97cd 100644
--- a/include/armnn/BackendHelper.hpp
+++ b/include/armnn/BackendHelper.hpp
@@ -185,6 +185,11 @@ public:
const GatherDescriptor& descriptor,
Optional<std::string&> reasonIfUnsupported = EmptyOptional());
+ bool IsGatherNdSupported(const TensorInfo& input0,
+ const TensorInfo& input1,
+ const TensorInfo& output,
+ Optional<std::string&> reasonIfUnsupported = EmptyOptional());
+
bool IsInputSupported(const TensorInfo& input,
Optional<std::string&> reasonIfUnsupported = EmptyOptional());
diff --git a/include/armnn/INetwork.hpp b/include/armnn/INetwork.hpp
index eaec973899..7488fdc026 100644
--- a/include/armnn/INetwork.hpp
+++ b/include/armnn/INetwork.hpp
@@ -651,6 +651,11 @@ public:
IConnectableLayer* AddGatherLayer(const GatherDescriptor& descriptor,
const char* name = nullptr);
+ /// Add GatherNd layer to the network.
+ /// @param name - Optional name for the layer.
+ /// @return - Interface for configuring the layer.
+ IConnectableLayer* AddGatherNdLayer(const char* name = nullptr);
+
/// Adds a switch layer to the network.
/// @param name - Optional name for the layer.
/// @return - Interface for configuring the layer.
diff --git a/include/armnn/Types.hpp b/include/armnn/Types.hpp
index a804f55468..cc704a64ae 100644
--- a/include/armnn/Types.hpp
+++ b/include/armnn/Types.hpp
@@ -458,6 +458,7 @@ using InferenceTimingPair = std::pair<HighResolutionClock, HighResolutionClock>;
X(ChannelShuffle) \
X(Convolution3d) \
X(Pooling3d) \
+ X(GatherNd)\
// New layers should be added at last to minimize instability.
diff --git a/include/armnn/backends/WorkloadData.hpp b/include/armnn/backends/WorkloadData.hpp
index 21141583c6..ed89f9638c 100644
--- a/include/armnn/backends/WorkloadData.hpp
+++ b/include/armnn/backends/WorkloadData.hpp
@@ -527,6 +527,11 @@ struct RsqrtQueueDescriptor : QueueDescriptor
void Validate(const WorkloadInfo& workloadInfo) const;
};
+struct GatherNdQueueDescriptor : QueueDescriptor
+{
+ void Validate(const WorkloadInfo& workloadInfo) const;
+};
+
struct GatherQueueDescriptor : QueueDescriptorWithParameters<GatherDescriptor>
{
void Validate(const WorkloadInfo& workloadInfo) const;