diff options
Diffstat (limited to 'include')
-rw-r--r-- | include/armnn/Descriptors.hpp | 20 | ||||
-rw-r--r-- | include/armnn/DescriptorsFwd.hpp | 1 | ||||
-rw-r--r-- | include/armnn/ILayerSupport.hpp | 8 | ||||
-rw-r--r-- | include/armnn/ILayerVisitor.hpp | 10 | ||||
-rw-r--r-- | include/armnn/INetwork.hpp | 9 | ||||
-rw-r--r-- | include/armnn/LayerVisitorBase.hpp | 4 | ||||
-rw-r--r-- | include/armnn/Types.hpp | 10 | ||||
-rw-r--r-- | include/armnn/TypesUtils.hpp | 14 |
8 files changed, 76 insertions, 0 deletions
diff --git a/include/armnn/Descriptors.hpp b/include/armnn/Descriptors.hpp index 92e842b2c1..10d8ab7a08 100644 --- a/include/armnn/Descriptors.hpp +++ b/include/armnn/Descriptors.hpp @@ -58,6 +58,26 @@ struct ArgMinMaxDescriptor int m_Axis; }; +/// A ComparisonDescriptor for the ComparisonLayer +struct ComparisonDescriptor +{ + ComparisonDescriptor() + : ComparisonDescriptor(ComparisonOperation::Equal) + {} + + ComparisonDescriptor(ComparisonOperation operation) + : m_Operation(operation) + {} + + bool operator ==(const ComparisonDescriptor &rhs) const + { + return m_Operation == rhs.m_Operation; + } + + /// Specifies the comparison operation to execute + ComparisonOperation m_Operation; +}; + /// A PermuteDescriptor for the PermuteLayer. struct PermuteDescriptor { diff --git a/include/armnn/DescriptorsFwd.hpp b/include/armnn/DescriptorsFwd.hpp index 6f1c0e0a6e..a978f7739a 100644 --- a/include/armnn/DescriptorsFwd.hpp +++ b/include/armnn/DescriptorsFwd.hpp @@ -12,6 +12,7 @@ struct ActivationDescriptor; struct ArgMinMaxDescriptor; struct BatchNormalizationDescriptor; struct BatchToSpaceNdDescriptor; +struct ComparisonDescriptor; struct Convolution2dDescriptor; struct DepthwiseConvolution2dDescriptor; struct DetectionPostProcessDescriptor; diff --git a/include/armnn/ILayerSupport.hpp b/include/armnn/ILayerSupport.hpp index 31b5e134e9..87197eed5a 100644 --- a/include/armnn/ILayerSupport.hpp +++ b/include/armnn/ILayerSupport.hpp @@ -61,6 +61,12 @@ public: const BatchToSpaceNdDescriptor& descriptor, Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0; + virtual bool IsComparisonSupported(const TensorInfo& input0, + const TensorInfo& input1, + const TensorInfo& output, + const ComparisonDescriptor& descriptor, + Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0; + virtual bool IsConcatSupported(const std::vector<const TensorInfo*> inputs, const TensorInfo& output, const OriginsDescriptor& descriptor, @@ -124,6 +130,7 @@ public: const TensorInfo& output, Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0; + ARMNN_DEPRECATED_MSG("Use IsComparisonSupported instead") virtual bool IsEqualSupported(const TensorInfo& input0, const TensorInfo& input1, const TensorInfo& output, @@ -149,6 +156,7 @@ public: const TensorInfo& output, Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0; + ARMNN_DEPRECATED_MSG("Use IsComparisonSupported instead") virtual bool IsGreaterSupported(const TensorInfo& input0, const TensorInfo& input1, const TensorInfo& ouput, diff --git a/include/armnn/ILayerVisitor.hpp b/include/armnn/ILayerVisitor.hpp index e99e10f800..80931ebcc0 100644 --- a/include/armnn/ILayerVisitor.hpp +++ b/include/armnn/ILayerVisitor.hpp @@ -74,6 +74,14 @@ public: const BatchToSpaceNdDescriptor& batchToSpaceNdDescriptor, const char* name = nullptr) = 0; + /// Function a Comparison layer should call back to when its Accept(ILayerVisitor&) function is invoked. + /// @param layer - pointer to the layer which is calling back to this visit function. + /// @param comparisonDescriptor - Description of the layer. + /// @param name - Optional name for the layer. + virtual void VisitComparisonLayer(const IConnectableLayer* layer, + const ComparisonDescriptor& comparisonDescriptor, + const char* name = nullptr) = 0; + /// Function that a concat layer should call back to when its Accept(ILayerVisitor&) function is invoked. /// @param layer - pointer to the layer which is calling back to this visit function. /// @param concatDescriptor - ConcatDescriptor (synonym for OriginsDescriptor) to configure the concatenation @@ -163,6 +171,7 @@ public: /// Function an Equal layer should call back to when its Accept(ILayerVisitor&) function is invoked. /// @param layer - pointer to the layer which is calling back to this visit function. /// @param name - Optional name for the layer. + ARMNN_DEPRECATED_MSG("Use VisitComparisonLayer instead") virtual void VisitEqualLayer(const IConnectableLayer* layer, const char* name = nullptr) = 0; @@ -194,6 +203,7 @@ public: /// Function a Greater layer should call back to when its Accept(ILayerVisitor&) function is invoked. /// @param layer - pointer to the layer which is calling back to this visit function. /// @param name - Optional name for the layer. + ARMNN_DEPRECATED_MSG("Use VisitComparisonLayer instead") virtual void VisitGreaterLayer(const IConnectableLayer* layer, const char* name = nullptr) = 0; diff --git a/include/armnn/INetwork.hpp b/include/armnn/INetwork.hpp index d12f5c239c..b3fab82cb4 100644 --- a/include/armnn/INetwork.hpp +++ b/include/armnn/INetwork.hpp @@ -109,6 +109,13 @@ public: virtual IConnectableLayer* AddArgMinMaxLayer(const ArgMinMaxDescriptor& desc, const char* name = nullptr) = 0; + /// Add a Comparison layer to the network. + /// @param name - Optional name for the layer. + /// @param desc - Descriptor for the comparison operation. + /// @ return - Interface for configuring the layer. + virtual IConnectableLayer* AddComparisonLayer(const ComparisonDescriptor& comparisonDescriptor, + const char* name = nullptr) = 0; + /// Adds a concatenation layer to the network. /// @param concatDescriptor - ConcatDescriptor (synonym for OriginsDescriptor) to configure the concatenation /// process. Number of Views must be equal to the number of inputs, and their order @@ -453,11 +460,13 @@ public: /// Add a Greater layer to the network. /// @param name - Optional name for the layer. /// @ return - Interface for configuring the layer. + ARMNN_DEPRECATED_MSG("Use AddComparisonLayer instead") virtual IConnectableLayer* AddGreaterLayer(const char* name = nullptr) = 0; /// Add a Equal layer to the network. /// @param name - Optional name for the layer. /// @ return - Interface for configuring the layer. + ARMNN_DEPRECATED_MSG("Use AddComparisonLayer instead") virtual IConnectableLayer* AddEqualLayer(const char* name = nullptr) = 0; /// Add Reciprocal of square root layer to the network. diff --git a/include/armnn/LayerVisitorBase.hpp b/include/armnn/LayerVisitorBase.hpp index 912f25500c..5226fa2f66 100644 --- a/include/armnn/LayerVisitorBase.hpp +++ b/include/armnn/LayerVisitorBase.hpp @@ -55,6 +55,10 @@ public: const BatchToSpaceNdDescriptor&, const char*) override { DefaultPolicy::Apply(__func__); } + void VisitComparisonLayer(const IConnectableLayer*, + const ComparisonDescriptor&, + const char*) override { DefaultPolicy::Apply(__func__); } + void VisitConcatLayer(const IConnectableLayer*, const ConcatDescriptor&, const char*) override { DefaultPolicy::Apply(__func__); } diff --git a/include/armnn/Types.hpp b/include/armnn/Types.hpp index dbcb91aebe..16a148c9c2 100644 --- a/include/armnn/Types.hpp +++ b/include/armnn/Types.hpp @@ -60,6 +60,16 @@ enum class ArgMinMaxFunction Max = 1 }; +enum class ComparisonOperation +{ + Equal = 0, + Greater = 1, + GreaterOrEqual = 2, + Less = 3, + LessOrEqual = 4, + NotEqual = 5 +}; + enum class PoolingAlgorithm { Max = 0, diff --git a/include/armnn/TypesUtils.hpp b/include/armnn/TypesUtils.hpp index cb52471cd5..310792c2a4 100644 --- a/include/armnn/TypesUtils.hpp +++ b/include/armnn/TypesUtils.hpp @@ -42,6 +42,20 @@ constexpr char const* GetActivationFunctionAsCString(ActivationFunction activati } } +constexpr char const* GetComparisonOperationAsCString(ComparisonOperation operation) +{ + switch (operation) + { + case ComparisonOperation::Equal: return "Equal"; + case ComparisonOperation::Greater: return "Greater"; + case ComparisonOperation::GreaterOrEqual: return "GreaterOrEqual"; + case ComparisonOperation::Less: return "Less"; + case ComparisonOperation::LessOrEqual: return "LessOrEqual"; + case ComparisonOperation::NotEqual: return "NotEqual"; + default: return "Unknown"; + } +} + constexpr char const* GetPoolingAlgorithmAsCString(PoolingAlgorithm pooling) { switch (pooling) |