diff options
Diffstat (limited to 'include/armnn')
-rw-r--r-- | include/armnn/Descriptors.hpp | 20 | ||||
-rw-r--r-- | include/armnn/DescriptorsFwd.hpp | 1 | ||||
-rw-r--r-- | include/armnn/ILayerSupport.hpp | 11 | ||||
-rw-r--r-- | include/armnn/ILayerVisitor.hpp | 8 | ||||
-rw-r--r-- | include/armnn/INetwork.hpp | 7 | ||||
-rw-r--r-- | include/armnn/LayerVisitorBase.hpp | 4 | ||||
-rw-r--r-- | include/armnn/Types.hpp | 17 | ||||
-rw-r--r-- | include/armnn/TypesUtils.hpp | 23 |
8 files changed, 80 insertions, 11 deletions
diff --git a/include/armnn/Descriptors.hpp b/include/armnn/Descriptors.hpp index 2834336fb2..ac0d585751 100644 --- a/include/armnn/Descriptors.hpp +++ b/include/armnn/Descriptors.hpp @@ -1266,4 +1266,24 @@ struct TransposeDescriptor PermutationVector m_DimMappings; }; +/// A LogicalBinaryDescriptor for the LogicalBinaryLayer +struct LogicalBinaryDescriptor +{ + LogicalBinaryDescriptor() + : LogicalBinaryDescriptor(LogicalBinaryOperation::LogicalAnd) + {} + + LogicalBinaryDescriptor(LogicalBinaryOperation operation) + : m_Operation(operation) + {} + + bool operator ==(const LogicalBinaryDescriptor &rhs) const + { + return m_Operation == rhs.m_Operation; + } + + /// Specifies the logical operation to execute + LogicalBinaryOperation m_Operation; +}; + } // namespace armnn diff --git a/include/armnn/DescriptorsFwd.hpp b/include/armnn/DescriptorsFwd.hpp index fba976c788..dff5ec73ad 100644 --- a/include/armnn/DescriptorsFwd.hpp +++ b/include/armnn/DescriptorsFwd.hpp @@ -23,6 +23,7 @@ struct FullyConnectedDescriptor; struct GatherDescriptor; struct InstanceNormalizationDescriptor; struct L2NormalizationDescriptor; +struct LogicalBinaryDescriptor; struct LstmDescriptor; struct MeanDescriptor; struct NormalizationDescriptor; diff --git a/include/armnn/ILayerSupport.hpp b/include/armnn/ILayerSupport.hpp index ed234fe4db..200409361c 100644 --- a/include/armnn/ILayerSupport.hpp +++ b/include/armnn/ILayerSupport.hpp @@ -205,6 +205,17 @@ public: const L2NormalizationDescriptor& descriptor, Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0; + virtual bool IsLogicalBinarySupported(const TensorInfo& input0, + const TensorInfo& input1, + const TensorInfo& output, + const LogicalBinaryDescriptor& descriptor, + Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0; + + virtual bool IsLogicalUnarySupported(const TensorInfo& input, + const TensorInfo& output, + const ElementwiseUnaryDescriptor& descriptor, + Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0; + virtual bool IsLogSoftmaxSupported(const TensorInfo& input, const TensorInfo& output, const LogSoftmaxDescriptor& descriptor, diff --git a/include/armnn/ILayerVisitor.hpp b/include/armnn/ILayerVisitor.hpp index 385ad62225..b8e741d01e 100644 --- a/include/armnn/ILayerVisitor.hpp +++ b/include/armnn/ILayerVisitor.hpp @@ -268,6 +268,14 @@ public: const LogSoftmaxDescriptor& logSoftmaxDescriptor, const char* name = nullptr) = 0; + /// Function that a logical binary layer should call back to when its Accept(ILayerVisitor&) function is invoked. + /// @param layer - pointer to the layer which is calling back to this visit function. + /// @param logicalBinaryDescriptor - LogicalBinaryDescriptor to configure the logical unary layer. + /// @param name - Optional name for the layer. + virtual void VisitLogicalBinaryLayer(const IConnectableLayer* layer, + const LogicalBinaryDescriptor& logicalBinaryDescriptor, + const char* name = nullptr) = 0; + /// Function an Lstm layer should call back to when its Accept(ILayerVisitor&) function is invoked. /// @param layer - pointer to the layer which is calling back to this visit function. /// @param descriptor - Parameters controlling the operation of the Lstm operation. diff --git a/include/armnn/INetwork.hpp b/include/armnn/INetwork.hpp index 70ad94fa51..b89df63aa4 100644 --- a/include/armnn/INetwork.hpp +++ b/include/armnn/INetwork.hpp @@ -584,6 +584,13 @@ public: const LstmInputParams& params, const char* name = nullptr) = 0; + /// Adds a Logical Binary layer to the network. + /// @param descriptor - Description of the Logical Binary layer. + /// @param name - Optional name for the layer. + /// @return - Interface for configuring the layer. + virtual IConnectableLayer* AddLogicalBinaryLayer(const LogicalBinaryDescriptor& descriptor, + const char* name = nullptr) = 0; + virtual void Accept(ILayerVisitor& visitor) const = 0; protected: diff --git a/include/armnn/LayerVisitorBase.hpp b/include/armnn/LayerVisitorBase.hpp index 75237a4372..a97530ec40 100644 --- a/include/armnn/LayerVisitorBase.hpp +++ b/include/armnn/LayerVisitorBase.hpp @@ -141,6 +141,10 @@ public: const LogSoftmaxDescriptor&, const char*) override { DefaultPolicy::Apply(__func__); } + void VisitLogicalBinaryLayer(const IConnectableLayer*, + const LogicalBinaryDescriptor&, + const char*) override {DefaultPolicy::Apply(__func__); } + void VisitLstmLayer(const IConnectableLayer*, const LstmDescriptor&, const LstmInputParams&, diff --git a/include/armnn/Types.hpp b/include/armnn/Types.hpp index 4a01549a14..46f246f46d 100644 --- a/include/armnn/Types.hpp +++ b/include/armnn/Types.hpp @@ -85,13 +85,20 @@ enum class ComparisonOperation NotEqual = 5 }; +enum class LogicalBinaryOperation +{ + LogicalAnd = 0, + LogicalOr = 1 +}; + enum class UnaryOperation { - Abs = 0, - Exp = 1, - Sqrt = 2, - Rsqrt = 3, - Neg = 4 + Abs = 0, + Exp = 1, + Sqrt = 2, + Rsqrt = 3, + Neg = 4, + LogicalNot = 5 }; enum class PoolingAlgorithm diff --git a/include/armnn/TypesUtils.hpp b/include/armnn/TypesUtils.hpp index efc69deb67..1012fcfa22 100644 --- a/include/armnn/TypesUtils.hpp +++ b/include/armnn/TypesUtils.hpp @@ -72,12 +72,23 @@ constexpr char const* GetUnaryOperationAsCString(UnaryOperation operation) { switch (operation) { - case UnaryOperation::Abs: return "Abs"; - case UnaryOperation::Exp: return "Exp"; - case UnaryOperation::Sqrt: return "Sqrt"; - case UnaryOperation::Rsqrt: return "Rsqrt"; - case UnaryOperation::Neg: return "Neg"; - default: return "Unknown"; + case UnaryOperation::Abs: return "Abs"; + case UnaryOperation::Exp: return "Exp"; + case UnaryOperation::Sqrt: return "Sqrt"; + case UnaryOperation::Rsqrt: return "Rsqrt"; + case UnaryOperation::Neg: return "Neg"; + case UnaryOperation::LogicalNot: return "LogicalNot"; + default: return "Unknown"; + } +} + +constexpr char const* GetLogicalBinaryOperationAsCString(LogicalBinaryOperation operation) +{ + switch (operation) + { + case LogicalBinaryOperation::LogicalAnd: return "LogicalAnd"; + case LogicalBinaryOperation::LogicalOr: return "LogicalOr"; + default: return "Unknown"; } } |