From 4a3c61091037e7e86e8b03bb060d8c1ab82731a9 Mon Sep 17 00:00:00 2001 From: josh minor Date: Mon, 6 Jan 2020 16:40:46 -0600 Subject: IVGCVSW-4259 Add frontend and reference workload for UnaryOperationLayer * Added new layer named ElementwiseUnary * Deprecated existing Abs/Rsqrt layer functions * Updated existing Abs/Rsqrt test infrastructure to use new layer * Added boilerplate for new Exp,Neg,Sqrt elemwise op layers * AbsQuantize test removed pending future commit * Serialization support added !android-nn-driver:2550 Change-Id: Ic595c645925e17b45db568187fd05646daf2e87f Signed-off-by: josh minor --- include/armnn/Descriptors.hpp | 20 ++++++++++++++++++++ include/armnn/DescriptorsFwd.hpp | 1 + include/armnn/ILayerSupport.hpp | 7 +++++++ include/armnn/ILayerVisitor.hpp | 10 ++++++++++ include/armnn/INetwork.hpp | 9 +++++++++ include/armnn/LayerVisitorBase.hpp | 4 ++++ include/armnn/Types.hpp | 9 +++++++++ include/armnn/TypesUtils.hpp | 13 +++++++++++++ 8 files changed, 73 insertions(+) (limited to 'include/armnn') diff --git a/include/armnn/Descriptors.hpp b/include/armnn/Descriptors.hpp index ba9a56ad38..45c0f421f3 100644 --- a/include/armnn/Descriptors.hpp +++ b/include/armnn/Descriptors.hpp @@ -78,6 +78,26 @@ struct ComparisonDescriptor ComparisonOperation m_Operation; }; +/// A ElementwiseUnaryDescriptor for the ElementwiseUnaryLayer +struct ElementwiseUnaryDescriptor +{ + ElementwiseUnaryDescriptor() + : ElementwiseUnaryDescriptor(UnaryOperation::Abs) + {} + + ElementwiseUnaryDescriptor(UnaryOperation operation) + : m_Operation(operation) + {} + + bool operator ==(const ElementwiseUnaryDescriptor &rhs) const + { + return m_Operation == rhs.m_Operation; + } + + /// Specifies the elementwiseUnary operation to execute + UnaryOperation m_Operation; +}; + /// A PermuteDescriptor for the PermuteLayer. struct PermuteDescriptor { diff --git a/include/armnn/DescriptorsFwd.hpp b/include/armnn/DescriptorsFwd.hpp index cfdef8a030..d03c61d452 100644 --- a/include/armnn/DescriptorsFwd.hpp +++ b/include/armnn/DescriptorsFwd.hpp @@ -16,6 +16,7 @@ struct ComparisonDescriptor; struct Convolution2dDescriptor; struct DepthwiseConvolution2dDescriptor; struct DetectionPostProcessDescriptor; +struct ElementwiseUnaryDescriptor; struct FakeQuantizationDescriptor; struct FullyConnectedDescriptor; struct InstanceNormalizationDescriptor; diff --git a/include/armnn/ILayerSupport.hpp b/include/armnn/ILayerSupport.hpp index 452200291e..1615d3e24e 100644 --- a/include/armnn/ILayerSupport.hpp +++ b/include/armnn/ILayerSupport.hpp @@ -27,6 +27,7 @@ protected: virtual ~ILayerSupport() {} public: + ARMNN_DEPRECATED_MSG("Use IsElementwiseUnarySupported instead") virtual bool IsAbsSupported(const TensorInfo& input, const TensorInfo& output, Optional reasonIfUnsupported = EmptyOptional()) const = 0; @@ -133,6 +134,11 @@ public: const TensorInfo& output, Optional reasonIfUnsupported = EmptyOptional()) const = 0; + virtual bool IsElementwiseUnarySupported(const TensorInfo& input, + const TensorInfo& output, + const ElementwiseUnaryDescriptor& descriptor, + Optional reasonIfUnsupported = EmptyOptional()) const = 0; + ARMNN_DEPRECATED_MSG("Use IsComparisonSupported instead") virtual bool IsEqualSupported(const TensorInfo& input0, const TensorInfo& input1, @@ -292,6 +298,7 @@ public: const ResizeDescriptor& descriptor, Optional reasonIfUnsupported = EmptyOptional()) const = 0; + ARMNN_DEPRECATED_MSG("Use IsElementwiseUnarySupported instead") virtual bool IsRsqrtSupported(const TensorInfo& input, const TensorInfo& output, Optional reasonIfUnsupported = EmptyOptional()) const = 0; diff --git a/include/armnn/ILayerVisitor.hpp b/include/armnn/ILayerVisitor.hpp index 9669b3a7cb..46f9e5698f 100644 --- a/include/armnn/ILayerVisitor.hpp +++ b/include/armnn/ILayerVisitor.hpp @@ -24,6 +24,7 @@ public: /// function is invoked. /// @param layer - pointer to the layer which is calling back to this visit function. /// @param name - Optional name for the layer. + ARMNN_DEPRECATED_MSG("Use VisitElementwiseUnaryLayer instead") virtual void VisitAbsLayer(const IConnectableLayer* layer, const char* name = nullptr) = 0; @@ -168,6 +169,14 @@ public: virtual void VisitDivisionLayer(const IConnectableLayer* layer, const char* name = nullptr) = 0; + /// Function a ElementwiseUnary layer should call back to when its Accept(ILayerVisitor&) function is invoked. + /// @param layer - pointer to the layer which is calling back to this visit function. + /// @param elementwiseUnaryDescriptor - Description of the layer. + /// @param name - Optional name for the layer. + virtual void VisitElementwiseUnaryLayer(const IConnectableLayer* layer, + const ElementwiseUnaryDescriptor& elementwiseUnaryDescriptor, + const char* name = nullptr) = 0; + /// Function an Equal layer should call back to when its Accept(ILayerVisitor&) function is invoked. /// @param layer - pointer to the layer which is calling back to this visit function. /// @param name - Optional name for the layer. @@ -388,6 +397,7 @@ public: /// function is invoked. /// @param layer - pointer to the layer which is calling back to this visit function. /// @param name - Optional name for the layer. + ARMNN_DEPRECATED_MSG("Use VisitElementwiseUnaryLayer instead") virtual void VisitRsqrtLayer(const IConnectableLayer* layer, const char* name = nullptr) = 0; diff --git a/include/armnn/INetwork.hpp b/include/armnn/INetwork.hpp index 647f072804..1b1c874f8c 100644 --- a/include/armnn/INetwork.hpp +++ b/include/armnn/INetwork.hpp @@ -196,6 +196,13 @@ public: const ConstTensor& anchors, const char* name = nullptr) = 0; + /// Add an ElementwiseUnary layer to the network. + /// @param name - Optional name for the layer. + /// @param desc - Descriptor for the elementwiseUnary operation. + /// @ return - Interface for configuring the layer. + virtual IConnectableLayer* AddElementwiseUnaryLayer(const ElementwiseUnaryDescriptor& elementwiseUnaryDescriptor, + const char* name = nullptr) = 0; + /// Adds a fully connected layer to the network. /// @param fullyConnectedDescriptor - Description of the fully connected layer. /// @param weights - Tensor for the weights data. @@ -297,6 +304,7 @@ public: /// Add absolute layer to the network. /// @param name - Optional name for the layer. /// @ return - Interface for configuring the layer. + ARMNN_DEPRECATED_MSG("Use AddElementwiseUnaryLayer instead") virtual IConnectableLayer* AddAbsLayer(const char* name = nullptr) = 0; /// Adds an addition layer to the network. @@ -474,6 +482,7 @@ public: /// Add Reciprocal of square root layer to the network. /// @param name - Optional name for the layer. /// @ return - Interface for configuring the layer. + ARMNN_DEPRECATED_MSG("Use AddElementwiseUnaryLayer instead") virtual IConnectableLayer* AddRsqrtLayer(const char* name = nullptr) = 0; /// Add Gather layer to the network. diff --git a/include/armnn/LayerVisitorBase.hpp b/include/armnn/LayerVisitorBase.hpp index 388fc6f922..6fd9a66c76 100644 --- a/include/armnn/LayerVisitorBase.hpp +++ b/include/armnn/LayerVisitorBase.hpp @@ -94,6 +94,10 @@ public: void VisitDivisionLayer(const IConnectableLayer*, const char*) override { DefaultPolicy::Apply(__func__); } + void VisitElementwiseUnaryLayer(const IConnectableLayer*, + const ElementwiseUnaryDescriptor&, + const char*) override { DefaultPolicy::Apply(__func__); } + void VisitEqualLayer(const IConnectableLayer*, const char*) override { DefaultPolicy::Apply(__func__); } diff --git a/include/armnn/Types.hpp b/include/armnn/Types.hpp index 5ea214e1dc..1ab5660109 100644 --- a/include/armnn/Types.hpp +++ b/include/armnn/Types.hpp @@ -80,6 +80,15 @@ enum class ComparisonOperation NotEqual = 5 }; +enum class UnaryOperation +{ + Abs = 0, + Exp = 1, + Sqrt = 2, + Rsqrt = 3, + Neg = 4 +}; + enum class PoolingAlgorithm { Max = 0, diff --git a/include/armnn/TypesUtils.hpp b/include/armnn/TypesUtils.hpp index 8157d4f043..790f57a432 100644 --- a/include/armnn/TypesUtils.hpp +++ b/include/armnn/TypesUtils.hpp @@ -66,6 +66,19 @@ constexpr char const* GetComparisonOperationAsCString(ComparisonOperation operat } } +constexpr char const* GetUnaryOperationAsCString(UnaryOperation operation) +{ + switch (operation) + { + case UnaryOperation::Abs: return "Abs"; + case UnaryOperation::Exp: return "Exp"; + case UnaryOperation::Sqrt: return "Sqrt"; + case UnaryOperation::Rsqrt: return "Rsqrt"; + case UnaryOperation::Neg: return "Neg"; + default: return "Unknown"; + } +} + constexpr char const* GetPoolingAlgorithmAsCString(PoolingAlgorithm pooling) { switch (pooling) -- cgit v1.2.1