aboutsummaryrefslogtreecommitdiff
path: root/include
diff options
context:
space:
mode:
authorAron Virginas-Tar <Aron.Virginas-Tar@arm.com>2019-10-16 17:45:38 +0100
committerÁron Virginás-Tar <aron.virginas-tar@arm.com>2019-10-21 08:52:04 +0000
commit77bfb5e32faadb1383d48364a6f54adbff84ad80 (patch)
tree0bf5dfb48cb8d5c248baf716f02b9f481400316e /include
parent5884708e650a80e355398532bc320bbabdbb53f4 (diff)
downloadarmnn-77bfb5e32faadb1383d48364a6f54adbff84ad80.tar.gz
IVGCVSW-3993 Add frontend and reference workload for ComparisonLayer
* Added frontend for ComparisonLayer * Added RefComparisonWorkload * Deprecated and removed Equal and Greater layers and workloads * Updated tests to ensure backward compatibility Signed-off-by: Aron Virginas-Tar <Aron.Virginas-Tar@arm.com> Change-Id: Id50c880be1b567c531efff919c0c366d0a71cbe9
Diffstat (limited to 'include')
-rw-r--r--include/armnn/Descriptors.hpp20
-rw-r--r--include/armnn/DescriptorsFwd.hpp1
-rw-r--r--include/armnn/ILayerSupport.hpp8
-rw-r--r--include/armnn/ILayerVisitor.hpp10
-rw-r--r--include/armnn/INetwork.hpp9
-rw-r--r--include/armnn/LayerVisitorBase.hpp4
-rw-r--r--include/armnn/Types.hpp10
-rw-r--r--include/armnn/TypesUtils.hpp14
8 files changed, 76 insertions, 0 deletions
diff --git a/include/armnn/Descriptors.hpp b/include/armnn/Descriptors.hpp
index 92e842b2c1..10d8ab7a08 100644
--- a/include/armnn/Descriptors.hpp
+++ b/include/armnn/Descriptors.hpp
@@ -58,6 +58,26 @@ struct ArgMinMaxDescriptor
int m_Axis;
};
+/// A ComparisonDescriptor for the ComparisonLayer
+struct ComparisonDescriptor
+{
+ ComparisonDescriptor()
+ : ComparisonDescriptor(ComparisonOperation::Equal)
+ {}
+
+ ComparisonDescriptor(ComparisonOperation operation)
+ : m_Operation(operation)
+ {}
+
+ bool operator ==(const ComparisonDescriptor &rhs) const
+ {
+ return m_Operation == rhs.m_Operation;
+ }
+
+ /// Specifies the comparison operation to execute
+ ComparisonOperation m_Operation;
+};
+
/// A PermuteDescriptor for the PermuteLayer.
struct PermuteDescriptor
{
diff --git a/include/armnn/DescriptorsFwd.hpp b/include/armnn/DescriptorsFwd.hpp
index 6f1c0e0a6e..a978f7739a 100644
--- a/include/armnn/DescriptorsFwd.hpp
+++ b/include/armnn/DescriptorsFwd.hpp
@@ -12,6 +12,7 @@ struct ActivationDescriptor;
struct ArgMinMaxDescriptor;
struct BatchNormalizationDescriptor;
struct BatchToSpaceNdDescriptor;
+struct ComparisonDescriptor;
struct Convolution2dDescriptor;
struct DepthwiseConvolution2dDescriptor;
struct DetectionPostProcessDescriptor;
diff --git a/include/armnn/ILayerSupport.hpp b/include/armnn/ILayerSupport.hpp
index 31b5e134e9..87197eed5a 100644
--- a/include/armnn/ILayerSupport.hpp
+++ b/include/armnn/ILayerSupport.hpp
@@ -61,6 +61,12 @@ public:
const BatchToSpaceNdDescriptor& descriptor,
Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
+ virtual bool IsComparisonSupported(const TensorInfo& input0,
+ const TensorInfo& input1,
+ const TensorInfo& output,
+ const ComparisonDescriptor& descriptor,
+ Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
+
virtual bool IsConcatSupported(const std::vector<const TensorInfo*> inputs,
const TensorInfo& output,
const OriginsDescriptor& descriptor,
@@ -124,6 +130,7 @@ public:
const TensorInfo& output,
Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
+ ARMNN_DEPRECATED_MSG("Use IsComparisonSupported instead")
virtual bool IsEqualSupported(const TensorInfo& input0,
const TensorInfo& input1,
const TensorInfo& output,
@@ -149,6 +156,7 @@ public:
const TensorInfo& output,
Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
+ ARMNN_DEPRECATED_MSG("Use IsComparisonSupported instead")
virtual bool IsGreaterSupported(const TensorInfo& input0,
const TensorInfo& input1,
const TensorInfo& ouput,
diff --git a/include/armnn/ILayerVisitor.hpp b/include/armnn/ILayerVisitor.hpp
index e99e10f800..80931ebcc0 100644
--- a/include/armnn/ILayerVisitor.hpp
+++ b/include/armnn/ILayerVisitor.hpp
@@ -74,6 +74,14 @@ public:
const BatchToSpaceNdDescriptor& batchToSpaceNdDescriptor,
const char* name = nullptr) = 0;
+ /// Function a Comparison layer should call back to when its Accept(ILayerVisitor&) function is invoked.
+ /// @param layer - pointer to the layer which is calling back to this visit function.
+ /// @param comparisonDescriptor - Description of the layer.
+ /// @param name - Optional name for the layer.
+ virtual void VisitComparisonLayer(const IConnectableLayer* layer,
+ const ComparisonDescriptor& comparisonDescriptor,
+ const char* name = nullptr) = 0;
+
/// Function that a concat layer should call back to when its Accept(ILayerVisitor&) function is invoked.
/// @param layer - pointer to the layer which is calling back to this visit function.
/// @param concatDescriptor - ConcatDescriptor (synonym for OriginsDescriptor) to configure the concatenation
@@ -163,6 +171,7 @@ public:
/// Function an Equal layer should call back to when its Accept(ILayerVisitor&) function is invoked.
/// @param layer - pointer to the layer which is calling back to this visit function.
/// @param name - Optional name for the layer.
+ ARMNN_DEPRECATED_MSG("Use VisitComparisonLayer instead")
virtual void VisitEqualLayer(const IConnectableLayer* layer,
const char* name = nullptr) = 0;
@@ -194,6 +203,7 @@ public:
/// Function a Greater layer should call back to when its Accept(ILayerVisitor&) function is invoked.
/// @param layer - pointer to the layer which is calling back to this visit function.
/// @param name - Optional name for the layer.
+ ARMNN_DEPRECATED_MSG("Use VisitComparisonLayer instead")
virtual void VisitGreaterLayer(const IConnectableLayer* layer,
const char* name = nullptr) = 0;
diff --git a/include/armnn/INetwork.hpp b/include/armnn/INetwork.hpp
index d12f5c239c..b3fab82cb4 100644
--- a/include/armnn/INetwork.hpp
+++ b/include/armnn/INetwork.hpp
@@ -109,6 +109,13 @@ public:
virtual IConnectableLayer* AddArgMinMaxLayer(const ArgMinMaxDescriptor& desc,
const char* name = nullptr) = 0;
+ /// Add a Comparison layer to the network.
+ /// @param name - Optional name for the layer.
+ /// @param desc - Descriptor for the comparison operation.
+ /// @ return - Interface for configuring the layer.
+ virtual IConnectableLayer* AddComparisonLayer(const ComparisonDescriptor& comparisonDescriptor,
+ const char* name = nullptr) = 0;
+
/// Adds a concatenation layer to the network.
/// @param concatDescriptor - ConcatDescriptor (synonym for OriginsDescriptor) to configure the concatenation
/// process. Number of Views must be equal to the number of inputs, and their order
@@ -453,11 +460,13 @@ public:
/// Add a Greater layer to the network.
/// @param name - Optional name for the layer.
/// @ return - Interface for configuring the layer.
+ ARMNN_DEPRECATED_MSG("Use AddComparisonLayer instead")
virtual IConnectableLayer* AddGreaterLayer(const char* name = nullptr) = 0;
/// Add a Equal layer to the network.
/// @param name - Optional name for the layer.
/// @ return - Interface for configuring the layer.
+ ARMNN_DEPRECATED_MSG("Use AddComparisonLayer instead")
virtual IConnectableLayer* AddEqualLayer(const char* name = nullptr) = 0;
/// Add Reciprocal of square root layer to the network.
diff --git a/include/armnn/LayerVisitorBase.hpp b/include/armnn/LayerVisitorBase.hpp
index 912f25500c..5226fa2f66 100644
--- a/include/armnn/LayerVisitorBase.hpp
+++ b/include/armnn/LayerVisitorBase.hpp
@@ -55,6 +55,10 @@ public:
const BatchToSpaceNdDescriptor&,
const char*) override { DefaultPolicy::Apply(__func__); }
+ void VisitComparisonLayer(const IConnectableLayer*,
+ const ComparisonDescriptor&,
+ const char*) override { DefaultPolicy::Apply(__func__); }
+
void VisitConcatLayer(const IConnectableLayer*,
const ConcatDescriptor&,
const char*) override { DefaultPolicy::Apply(__func__); }
diff --git a/include/armnn/Types.hpp b/include/armnn/Types.hpp
index dbcb91aebe..16a148c9c2 100644
--- a/include/armnn/Types.hpp
+++ b/include/armnn/Types.hpp
@@ -60,6 +60,16 @@ enum class ArgMinMaxFunction
Max = 1
};
+enum class ComparisonOperation
+{
+ Equal = 0,
+ Greater = 1,
+ GreaterOrEqual = 2,
+ Less = 3,
+ LessOrEqual = 4,
+ NotEqual = 5
+};
+
enum class PoolingAlgorithm
{
Max = 0,
diff --git a/include/armnn/TypesUtils.hpp b/include/armnn/TypesUtils.hpp
index cb52471cd5..310792c2a4 100644
--- a/include/armnn/TypesUtils.hpp
+++ b/include/armnn/TypesUtils.hpp
@@ -42,6 +42,20 @@ constexpr char const* GetActivationFunctionAsCString(ActivationFunction activati
}
}
+constexpr char const* GetComparisonOperationAsCString(ComparisonOperation operation)
+{
+ switch (operation)
+ {
+ case ComparisonOperation::Equal: return "Equal";
+ case ComparisonOperation::Greater: return "Greater";
+ case ComparisonOperation::GreaterOrEqual: return "GreaterOrEqual";
+ case ComparisonOperation::Less: return "Less";
+ case ComparisonOperation::LessOrEqual: return "LessOrEqual";
+ case ComparisonOperation::NotEqual: return "NotEqual";
+ default: return "Unknown";
+ }
+}
+
constexpr char const* GetPoolingAlgorithmAsCString(PoolingAlgorithm pooling)
{
switch (pooling)