From ee391d59dbe3305734de4ff7d98c27c8a5252624 Mon Sep 17 00:00:00 2001 From: Nikhil Raj Date: Thu, 5 Sep 2019 17:50:44 +0100 Subject: IVGCVSW-3722 Add front end support for ArgMinMax Change-Id: I31c5616bea3097f30cde68442d3222e0b0fe2235 Signed-off-by: Nikhil Raj --- include/armnn/Descriptors.hpp | 9 +++++++++ include/armnn/DescriptorsFwd.hpp | 1 + include/armnn/ILayerSupport.hpp | 6 ++++++ include/armnn/ILayerVisitor.hpp | 8 ++++++++ include/armnn/INetwork.hpp | 7 +++++++ include/armnn/LayerVisitorBase.hpp | 4 ++++ 6 files changed, 35 insertions(+) (limited to 'include') diff --git a/include/armnn/Descriptors.hpp b/include/armnn/Descriptors.hpp index 9630d86197..87f4bdb40e 100644 --- a/include/armnn/Descriptors.hpp +++ b/include/armnn/Descriptors.hpp @@ -30,6 +30,15 @@ struct ActivationDescriptor float m_B; }; +/// An ArgMinMaxDescriptor for ArgMinMaxLayer +struct ArgMinMaxDescriptor +{ + ArgMinMaxDescriptor() + : m_Axis(-1) {} + + int m_Axis; +}; + /// A PermuteDescriptor for the PermuteLayer. struct PermuteDescriptor { diff --git a/include/armnn/DescriptorsFwd.hpp b/include/armnn/DescriptorsFwd.hpp index eddf91f4ce..8f81b4fe3e 100644 --- a/include/armnn/DescriptorsFwd.hpp +++ b/include/armnn/DescriptorsFwd.hpp @@ -8,6 +8,7 @@ namespace armnn { struct ActivationDescriptor; +struct ArgMinMaxDescriptor; struct BatchNormalizationDescriptor; struct BatchToSpaceNdDescriptor; struct Convolution2dDescriptor; diff --git a/include/armnn/ILayerSupport.hpp b/include/armnn/ILayerSupport.hpp index c67569bf00..d168226402 100644 --- a/include/armnn/ILayerSupport.hpp +++ b/include/armnn/ILayerSupport.hpp @@ -14,6 +14,7 @@ #include #include #include +#include "ArmNN.hpp" namespace armnn { @@ -41,6 +42,11 @@ public: const TensorInfo& output, Optional reasonIfUnsupported = EmptyOptional()) const = 0; + virtual bool IsArgMinMaxSupported(const TensorInfo& input, + const TensorInfo& output, + const ArgMinMaxDescriptor& descriptor, + Optional reasonIfUnsupported = EmptyOptional()) const = 0; + virtual bool IsBatchNormalizationSupported(const TensorInfo& input, const TensorInfo& output, const TensorInfo& mean, diff --git a/include/armnn/ILayerVisitor.hpp b/include/armnn/ILayerVisitor.hpp index a22de878ca..a504a4190d 100644 --- a/include/armnn/ILayerVisitor.hpp +++ b/include/armnn/ILayerVisitor.hpp @@ -41,6 +41,14 @@ public: virtual void VisitAdditionLayer(const IConnectableLayer* layer, const char* name = nullptr) = 0; + /// Function that an arg min max layer should call back to when its Accept(ILayerVisitor&) function is invoked. + /// @param layer - pointer to the layer which is calling back to this visit function. + /// @param argMinMaxDescriptor - ArgMinMaxDescriptor to configure the activation. + /// @param name - Optional name for the layer. + virtual void VisitArgMinMaxLayer(const IConnectableLayer* layer, + const ArgMinMaxDescriptor& argMinMaxDescriptor, + const char* name = nullptr) = 0; + /// Function that a batch normalization layer should call back to when its Accept(ILayerVisitor&) /// function is invoked. /// @param layer - pointer to the layer which is calling back to this visit function. diff --git a/include/armnn/INetwork.hpp b/include/armnn/INetwork.hpp index ce0fda2707..cd1b7a6319 100644 --- a/include/armnn/INetwork.hpp +++ b/include/armnn/INetwork.hpp @@ -102,6 +102,13 @@ public: /// @return - Interface for configuring the layer. virtual IConnectableLayer* AddInputLayer(LayerBindingId id, const char* name = nullptr) = 0; + /// Adds an ArgMinMax layer to the network. + /// @param desc - Parameters for the L2 normalization operation. + /// @param name - Optional name for the layer. + /// @return - Interface for configuring the layer. + virtual IConnectableLayer* AddArgMinMaxLayer(const ArgMinMaxDescriptor& desc, + const char* name = nullptr) = 0; + /// Adds a concatenation layer to the network. /// @param concatDescriptor - ConcatDescriptor (synonym for OriginsDescriptor) to configure the concatenation /// process. Number of Views must be equal to the number of inputs, and their order diff --git a/include/armnn/LayerVisitorBase.hpp b/include/armnn/LayerVisitorBase.hpp index 363a09154d..0739b43736 100644 --- a/include/armnn/LayerVisitorBase.hpp +++ b/include/armnn/LayerVisitorBase.hpp @@ -39,6 +39,10 @@ public: void VisitAdditionLayer(const IConnectableLayer*, const char*) override { DefaultPolicy::Apply(__func__); } + void VisitArgMinMaxLayer(const IConnectableLayer*, + const ArgMinMaxDescriptor&, + const char*) override { DefaultPolicy::Apply(__func__); } + void VisitBatchNormalizationLayer(const IConnectableLayer*, const BatchNormalizationDescriptor&, const ConstTensor&, -- cgit v1.2.1