author     Aron Virginas-Tar <Aron.Virginas-Tar@arm.com>    2019-10-16 17:45:38 +0100
committer  Áron Virginás-Tar <aron.virginas-tar@arm.com>    2019-10-21 08:52:04 +0000
commit     77bfb5e32faadb1383d48364a6f54adbff84ad80 (patch)
tree       0bf5dfb48cb8d5c248baf716f02b9f481400316e
parent     5884708e650a80e355398532bc320bbabdbb53f4 (diff)
download   armnn-77bfb5e32faadb1383d48364a6f54adbff84ad80.tar.gz
IVGCVSW-3993 Add frontend and reference workload for ComparisonLayer
* Added frontend for ComparisonLayer
* Added RefComparisonWorkload
* Deprecated and removed Equal and Greater layers and workloads
* Updated tests to ensure backward compatibility

Signed-off-by: Aron Virginas-Tar <Aron.Virginas-Tar@arm.com>
Change-Id: Id50c880be1b567c531efff919c0c366d0a71cbe9
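As a usage illustration only (not part of this patch), the sketch below shows how the new frontend introduced in this change is intended to be called, based on the API added in include/armnn/Descriptors.hpp, include/armnn/Types.hpp and include/armnn/INetwork.hpp: a ComparisonDescriptor selects the ComparisonOperation, AddComparisonLayer replaces the deprecated AddEqualLayer/AddGreaterLayer calls, and the comparison output tensor is DataType::Boolean. The helper name CreateGreaterNetwork is purely illustrative.

    #include <armnn/Descriptors.hpp>
    #include <armnn/INetwork.hpp>
    #include <armnn/Tensor.hpp>
    #include <armnn/Types.hpp>

    // Illustrative helper (not part of this change): builds a two-input network
    // that computes input0 > input1 via the new Comparison frontend.
    armnn::INetworkPtr CreateGreaterNetwork()
    {
        using namespace armnn;

        const TensorInfo inputInfo ({ 2, 2, 2, 2 }, DataType::Float32);
        const TensorInfo outputInfo({ 2, 2, 2, 2 }, DataType::Boolean); // comparison output must be Boolean

        INetworkPtr network = INetwork::Create();

        IConnectableLayer* input0 = network->AddInputLayer(0);
        IConnectableLayer* input1 = network->AddInputLayer(1);

        // Previously: network->AddGreaterLayer("greater");
        ComparisonDescriptor descriptor(ComparisonOperation::Greater);
        IConnectableLayer* comparison = network->AddComparisonLayer(descriptor, "greater");

        IConnectableLayer* output = network->AddOutputLayer(0);

        input0->GetOutputSlot(0).Connect(comparison->GetInputSlot(0));
        input1->GetOutputSlot(0).Connect(comparison->GetInputSlot(1));
        comparison->GetOutputSlot(0).Connect(output->GetInputSlot(0));

        input0->GetOutputSlot(0).SetTensorInfo(inputInfo);
        input1->GetOutputSlot(0).SetTensorInfo(inputInfo);
        comparison->GetOutputSlot(0).SetTensorInfo(outputInfo);

        return network;
    }

For backward compatibility, the deprecated AddEqualLayer and AddGreaterLayer entry points remain and now forward to AddComparisonLayer with the corresponding ComparisonOperation (see the src/armnn/Network.cpp hunk below), so existing callers keep working while receiving deprecation warnings.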
-rw-r--r--  Android.mk | 3
-rw-r--r--  CMakeLists.txt | 6
-rw-r--r--  include/armnn/Descriptors.hpp | 20
-rw-r--r--  include/armnn/DescriptorsFwd.hpp | 1
-rw-r--r--  include/armnn/ILayerSupport.hpp | 8
-rw-r--r--  include/armnn/ILayerVisitor.hpp | 10
-rw-r--r--  include/armnn/INetwork.hpp | 9
-rw-r--r--  include/armnn/LayerVisitorBase.hpp | 4
-rw-r--r--  include/armnn/Types.hpp | 10
-rw-r--r--  include/armnn/TypesUtils.hpp | 14
-rw-r--r--  src/armnn/InternalTypes.cpp | 3
-rw-r--r--  src/armnn/InternalTypes.hpp | 3
-rw-r--r--  src/armnn/LayerSupport.cpp | 14
-rw-r--r--  src/armnn/LayersFwd.hpp | 6
-rw-r--r--  src/armnn/Network.cpp | 10
-rw-r--r--  src/armnn/Network.hpp | 5
-rw-r--r--  src/armnn/layers/ComparisonLayer.cpp | 80
-rw-r--r--  src/armnn/layers/ComparisonLayer.hpp | 50
-rw-r--r--  src/armnn/layers/EqualLayer.cpp | 39
-rw-r--r--  src/armnn/layers/EqualLayer.hpp | 38
-rw-r--r--  src/armnn/layers/GreaterLayer.cpp | 39
-rw-r--r--  src/armnn/layers/GreaterLayer.hpp | 39
-rw-r--r--  src/armnn/test/TestNameAndDescriptorLayerVisitor.cpp | 9
-rw-r--r--  src/armnn/test/TestNameAndDescriptorLayerVisitor.hpp | 1
-rw-r--r--  src/armnn/test/TestNameOnlyLayerVisitor.cpp | 2
-rw-r--r--  src/armnn/test/TestNameOnlyLayerVisitor.hpp | 2
-rw-r--r--  src/armnnDeserializer/Deserializer.cpp | 6
-rw-r--r--  src/armnnSerializer/Serializer.cpp | 7
-rw-r--r--  src/armnnSerializer/Serializer.hpp | 6
-rw-r--r--  src/armnnSerializer/test/SerializerTests.cpp | 84
-rwxr-xr-x  src/armnnTfParser/TfParser.cpp | 6
-rw-r--r--  src/backends/backendsCommon/LayerSupportBase.cpp | 9
-rw-r--r--  src/backends/backendsCommon/LayerSupportBase.hpp | 8
-rw-r--r--  src/backends/backendsCommon/WorkloadData.cpp | 24
-rw-r--r--  src/backends/backendsCommon/WorkloadData.hpp | 5
-rw-r--r--  src/backends/backendsCommon/WorkloadFactory.cpp | 43
-rw-r--r--  src/backends/backendsCommon/WorkloadFactory.hpp | 5
-rw-r--r--  src/backends/backendsCommon/test/ArithmeticTestImpl.hpp | 113
-rw-r--r--  src/backends/backendsCommon/test/CMakeLists.txt | 2
-rw-r--r--  src/backends/backendsCommon/test/ComparisonEndToEndTestImpl.hpp | 103
-rw-r--r--  src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp | 6
-rw-r--r--  src/backends/backendsCommon/test/layerTests/ComparisonTestImpl.hpp | 126
-rw-r--r--  src/backends/backendsCommon/test/layerTests/EqualTestImpl.cpp | 57
-rw-r--r--  src/backends/backendsCommon/test/layerTests/GreaterTestImpl.cpp | 68
-rw-r--r--  src/backends/cl/ClLayerSupport.cpp | 25
-rw-r--r--  src/backends/cl/ClLayerSupport.hpp | 7
-rw-r--r--  src/backends/cl/ClWorkloadFactory.cpp | 28
-rw-r--r--  src/backends/cl/ClWorkloadFactory.hpp | 3
-rw-r--r--  src/backends/cl/test/ClEndToEndTests.cpp | 26
-rw-r--r--  src/backends/neon/NeonLayerSupport.cpp | 25
-rw-r--r--  src/backends/neon/NeonLayerSupport.hpp | 7
-rw-r--r--  src/backends/neon/NeonWorkloadFactory.cpp | 28
-rw-r--r--  src/backends/neon/NeonWorkloadFactory.hpp | 5
-rw-r--r--  src/backends/neon/test/NeonEndToEndTests.cpp | 26
-rw-r--r--  src/backends/reference/RefLayerSupport.cpp | 85
-rw-r--r--  src/backends/reference/RefLayerSupport.hpp | 8
-rw-r--r--  src/backends/reference/RefWorkloadFactory.cpp | 16
-rw-r--r--  src/backends/reference/RefWorkloadFactory.hpp | 5
-rw-r--r--  src/backends/reference/backend.mk | 1
-rw-r--r--  src/backends/reference/test/RefEndToEndTests.cpp | 50
-rw-r--r--  src/backends/reference/workloads/CMakeLists.txt | 2
-rw-r--r--  src/backends/reference/workloads/ElementwiseFunction.cpp | 7
-rw-r--r--  src/backends/reference/workloads/RefComparisonWorkload.cpp | 102
-rw-r--r--  src/backends/reference/workloads/RefComparisonWorkload.hpp | 34
-rw-r--r--  src/backends/reference/workloads/RefElementwiseWorkload.cpp | 8
-rw-r--r--  src/backends/reference/workloads/RefElementwiseWorkload.hpp | 9
-rw-r--r--  src/backends/reference/workloads/RefWorkloads.hpp | 1
-rw-r--r--  src/backends/reference/workloads/StringMapping.hpp | 4
68 files changed, 991 insertions, 624 deletions
diff --git a/Android.mk b/Android.mk
index 81a77e21bc..02941230f6 100644
--- a/Android.mk
+++ b/Android.mk
@@ -118,6 +118,7 @@ LOCAL_SRC_FILES := \
src/armnn/layers/ArgMinMaxLayer.cpp \
src/armnn/layers/BatchNormalizationLayer.cpp \
src/armnn/layers/BatchToSpaceNdLayer.cpp \
+ src/armnn/layers/ComparisonLayer.cpp \
src/armnn/layers/ConcatLayer.cpp \
src/armnn/layers/ConstantLayer.cpp \
src/armnn/layers/Convolution2dLayer.cpp \
@@ -130,12 +131,10 @@ LOCAL_SRC_FILES := \
src/armnn/layers/DetectionPostProcessLayer.cpp \
src/armnn/layers/DivisionLayer.cpp \
src/armnn/layers/ElementwiseBaseLayer.cpp \
- src/armnn/layers/EqualLayer.cpp \
src/armnn/layers/FakeQuantizationLayer.cpp \
src/armnn/layers/FloorLayer.cpp \
src/armnn/layers/FullyConnectedLayer.cpp \
src/armnn/layers/GatherLayer.cpp \
- src/armnn/layers/GreaterLayer.cpp \
src/armnn/layers/InputLayer.cpp \
src/armnn/layers/InstanceNormalizationLayer.cpp \
src/armnn/layers/L2NormalizationLayer.cpp \
diff --git a/CMakeLists.txt b/CMakeLists.txt
index e2712dd55a..626478a0e9 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -250,6 +250,8 @@ list(APPEND armnn_sources
src/armnn/layers/BatchNormalizationLayer.cpp
src/armnn/layers/BatchToSpaceNdLayer.hpp
src/armnn/layers/BatchToSpaceNdLayer.cpp
+ src/armnn/layers/ComparisonLayer.hpp
+ src/armnn/layers/ComparisonLayer.cpp
src/armnn/layers/ConcatLayer.hpp
src/armnn/layers/ConcatLayer.cpp
src/armnn/layers/ConstantLayer.hpp
@@ -272,8 +274,6 @@ list(APPEND armnn_sources
src/armnn/layers/DetectionPostProcessLayer.cpp
src/armnn/layers/ElementwiseBaseLayer.hpp
src/armnn/layers/ElementwiseBaseLayer.cpp
- src/armnn/layers/EqualLayer.hpp
- src/armnn/layers/EqualLayer.cpp
src/armnn/layers/FakeQuantizationLayer.hpp
src/armnn/layers/FakeQuantizationLayer.cpp
src/armnn/layers/FloorLayer.hpp
@@ -282,8 +282,6 @@ list(APPEND armnn_sources
src/armnn/layers/FullyConnectedLayer.cpp
src/armnn/layers/GatherLayer.cpp
src/armnn/layers/GatherLayer.hpp
- src/armnn/layers/GreaterLayer.cpp
- src/armnn/layers/GreaterLayer.hpp
src/armnn/layers/InputLayer.hpp
src/armnn/layers/InputLayer.cpp
src/armnn/layers/InstanceNormalizationLayer.hpp
diff --git a/include/armnn/Descriptors.hpp b/include/armnn/Descriptors.hpp
index 92e842b2c1..10d8ab7a08 100644
--- a/include/armnn/Descriptors.hpp
+++ b/include/armnn/Descriptors.hpp
@@ -58,6 +58,26 @@ struct ArgMinMaxDescriptor
int m_Axis;
};
+/// A ComparisonDescriptor for the ComparisonLayer
+struct ComparisonDescriptor
+{
+ ComparisonDescriptor()
+ : ComparisonDescriptor(ComparisonOperation::Equal)
+ {}
+
+ ComparisonDescriptor(ComparisonOperation operation)
+ : m_Operation(operation)
+ {}
+
+ bool operator ==(const ComparisonDescriptor &rhs) const
+ {
+ return m_Operation == rhs.m_Operation;
+ }
+
+ /// Specifies the comparison operation to execute
+ ComparisonOperation m_Operation;
+};
+
/// A PermuteDescriptor for the PermuteLayer.
struct PermuteDescriptor
{
diff --git a/include/armnn/DescriptorsFwd.hpp b/include/armnn/DescriptorsFwd.hpp
index 6f1c0e0a6e..a978f7739a 100644
--- a/include/armnn/DescriptorsFwd.hpp
+++ b/include/armnn/DescriptorsFwd.hpp
@@ -12,6 +12,7 @@ struct ActivationDescriptor;
struct ArgMinMaxDescriptor;
struct BatchNormalizationDescriptor;
struct BatchToSpaceNdDescriptor;
+struct ComparisonDescriptor;
struct Convolution2dDescriptor;
struct DepthwiseConvolution2dDescriptor;
struct DetectionPostProcessDescriptor;
diff --git a/include/armnn/ILayerSupport.hpp b/include/armnn/ILayerSupport.hpp
index 31b5e134e9..87197eed5a 100644
--- a/include/armnn/ILayerSupport.hpp
+++ b/include/armnn/ILayerSupport.hpp
@@ -61,6 +61,12 @@ public:
const BatchToSpaceNdDescriptor& descriptor,
Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
+ virtual bool IsComparisonSupported(const TensorInfo& input0,
+ const TensorInfo& input1,
+ const TensorInfo& output,
+ const ComparisonDescriptor& descriptor,
+ Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
+
virtual bool IsConcatSupported(const std::vector<const TensorInfo*> inputs,
const TensorInfo& output,
const OriginsDescriptor& descriptor,
@@ -124,6 +130,7 @@ public:
const TensorInfo& output,
Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
+ ARMNN_DEPRECATED_MSG("Use IsComparisonSupported instead")
virtual bool IsEqualSupported(const TensorInfo& input0,
const TensorInfo& input1,
const TensorInfo& output,
@@ -149,6 +156,7 @@ public:
const TensorInfo& output,
Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const = 0;
+ ARMNN_DEPRECATED_MSG("Use IsComparisonSupported instead")
virtual bool IsGreaterSupported(const TensorInfo& input0,
const TensorInfo& input1,
const TensorInfo& ouput,
diff --git a/include/armnn/ILayerVisitor.hpp b/include/armnn/ILayerVisitor.hpp
index e99e10f800..80931ebcc0 100644
--- a/include/armnn/ILayerVisitor.hpp
+++ b/include/armnn/ILayerVisitor.hpp
@@ -74,6 +74,14 @@ public:
const BatchToSpaceNdDescriptor& batchToSpaceNdDescriptor,
const char* name = nullptr) = 0;
+ /// Function a Comparison layer should call back to when its Accept(ILayerVisitor&) function is invoked.
+ /// @param layer - pointer to the layer which is calling back to this visit function.
+ /// @param comparisonDescriptor - Description of the layer.
+ /// @param name - Optional name for the layer.
+ virtual void VisitComparisonLayer(const IConnectableLayer* layer,
+ const ComparisonDescriptor& comparisonDescriptor,
+ const char* name = nullptr) = 0;
+
/// Function that a concat layer should call back to when its Accept(ILayerVisitor&) function is invoked.
/// @param layer - pointer to the layer which is calling back to this visit function.
/// @param concatDescriptor - ConcatDescriptor (synonym for OriginsDescriptor) to configure the concatenation
@@ -163,6 +171,7 @@ public:
/// Function an Equal layer should call back to when its Accept(ILayerVisitor&) function is invoked.
/// @param layer - pointer to the layer which is calling back to this visit function.
/// @param name - Optional name for the layer.
+ ARMNN_DEPRECATED_MSG("Use VisitComparisonLayer instead")
virtual void VisitEqualLayer(const IConnectableLayer* layer,
const char* name = nullptr) = 0;
@@ -194,6 +203,7 @@ public:
/// Function a Greater layer should call back to when its Accept(ILayerVisitor&) function is invoked.
/// @param layer - pointer to the layer which is calling back to this visit function.
/// @param name - Optional name for the layer.
+ ARMNN_DEPRECATED_MSG("Use VisitComparisonLayer instead")
virtual void VisitGreaterLayer(const IConnectableLayer* layer,
const char* name = nullptr) = 0;
diff --git a/include/armnn/INetwork.hpp b/include/armnn/INetwork.hpp
index d12f5c239c..b3fab82cb4 100644
--- a/include/armnn/INetwork.hpp
+++ b/include/armnn/INetwork.hpp
@@ -109,6 +109,13 @@ public:
virtual IConnectableLayer* AddArgMinMaxLayer(const ArgMinMaxDescriptor& desc,
const char* name = nullptr) = 0;
+ /// Add a Comparison layer to the network.
+ /// @param name - Optional name for the layer.
+ /// @param desc - Descriptor for the comparison operation.
+ /// @ return - Interface for configuring the layer.
+ virtual IConnectableLayer* AddComparisonLayer(const ComparisonDescriptor& comparisonDescriptor,
+ const char* name = nullptr) = 0;
+
/// Adds a concatenation layer to the network.
/// @param concatDescriptor - ConcatDescriptor (synonym for OriginsDescriptor) to configure the concatenation
/// process. Number of Views must be equal to the number of inputs, and their order
@@ -453,11 +460,13 @@ public:
/// Add a Greater layer to the network.
/// @param name - Optional name for the layer.
/// @ return - Interface for configuring the layer.
+ ARMNN_DEPRECATED_MSG("Use AddComparisonLayer instead")
virtual IConnectableLayer* AddGreaterLayer(const char* name = nullptr) = 0;
/// Add a Equal layer to the network.
/// @param name - Optional name for the layer.
/// @ return - Interface for configuring the layer.
+ ARMNN_DEPRECATED_MSG("Use AddComparisonLayer instead")
virtual IConnectableLayer* AddEqualLayer(const char* name = nullptr) = 0;
/// Add Reciprocal of square root layer to the network.
diff --git a/include/armnn/LayerVisitorBase.hpp b/include/armnn/LayerVisitorBase.hpp
index 912f25500c..5226fa2f66 100644
--- a/include/armnn/LayerVisitorBase.hpp
+++ b/include/armnn/LayerVisitorBase.hpp
@@ -55,6 +55,10 @@ public:
const BatchToSpaceNdDescriptor&,
const char*) override { DefaultPolicy::Apply(__func__); }
+ void VisitComparisonLayer(const IConnectableLayer*,
+ const ComparisonDescriptor&,
+ const char*) override { DefaultPolicy::Apply(__func__); }
+
void VisitConcatLayer(const IConnectableLayer*,
const ConcatDescriptor&,
const char*) override { DefaultPolicy::Apply(__func__); }
diff --git a/include/armnn/Types.hpp b/include/armnn/Types.hpp
index dbcb91aebe..16a148c9c2 100644
--- a/include/armnn/Types.hpp
+++ b/include/armnn/Types.hpp
@@ -60,6 +60,16 @@ enum class ArgMinMaxFunction
Max = 1
};
+enum class ComparisonOperation
+{
+ Equal = 0,
+ Greater = 1,
+ GreaterOrEqual = 2,
+ Less = 3,
+ LessOrEqual = 4,
+ NotEqual = 5
+};
+
enum class PoolingAlgorithm
{
Max = 0,
diff --git a/include/armnn/TypesUtils.hpp b/include/armnn/TypesUtils.hpp
index cb52471cd5..310792c2a4 100644
--- a/include/armnn/TypesUtils.hpp
+++ b/include/armnn/TypesUtils.hpp
@@ -42,6 +42,20 @@ constexpr char const* GetActivationFunctionAsCString(ActivationFunction activati
}
}
+constexpr char const* GetComparisonOperationAsCString(ComparisonOperation operation)
+{
+ switch (operation)
+ {
+ case ComparisonOperation::Equal: return "Equal";
+ case ComparisonOperation::Greater: return "Greater";
+ case ComparisonOperation::GreaterOrEqual: return "GreaterOrEqual";
+ case ComparisonOperation::Less: return "Less";
+ case ComparisonOperation::LessOrEqual: return "LessOrEqual";
+ case ComparisonOperation::NotEqual: return "NotEqual";
+ default: return "Unknown";
+ }
+}
+
constexpr char const* GetPoolingAlgorithmAsCString(PoolingAlgorithm pooling)
{
switch (pooling)
diff --git a/src/armnn/InternalTypes.cpp b/src/armnn/InternalTypes.cpp
index 7c39128bec..f713644656 100644
--- a/src/armnn/InternalTypes.cpp
+++ b/src/armnn/InternalTypes.cpp
@@ -20,6 +20,7 @@ char const* GetLayerTypeAsCString(LayerType type)
case LayerType::ArgMinMax: return "ArgMinMax";
case LayerType::BatchNormalization: return "BatchNormalization";
case LayerType::BatchToSpaceNd: return "BatchToSpaceNd";
+ case LayerType::Comparison: return "Comparison";
case LayerType::Concat: return "Concat";
case LayerType::Constant: return "Constant";
case LayerType::ConvertFp16ToFp32: return "ConvertFp16ToFp32";
@@ -31,12 +32,10 @@ char const* GetLayerTypeAsCString(LayerType type)
case LayerType::Dequantize: return "Dequantize";
case LayerType::DetectionPostProcess: return "DetectionPostProcess";
case LayerType::Division: return "Division";
- case LayerType::Equal: return "Equal";
case LayerType::FakeQuantization: return "FakeQuantization";
case LayerType::Floor: return "Floor";
case LayerType::FullyConnected: return "FullyConnected";
case LayerType::Gather: return "Gather";
- case LayerType::Greater: return "Greater";
case LayerType::Input: return "Input";
case LayerType::InstanceNormalization: return "InstanceNormalization";
case LayerType::L2Normalization: return "L2Normalization";
diff --git a/src/armnn/InternalTypes.hpp b/src/armnn/InternalTypes.hpp
index 895fe3235d..d7f932f699 100644
--- a/src/armnn/InternalTypes.hpp
+++ b/src/armnn/InternalTypes.hpp
@@ -20,6 +20,7 @@ enum class LayerType
ArgMinMax,
BatchNormalization,
BatchToSpaceNd,
+ Comparison,
Concat,
Constant,
ConvertFp16ToFp32,
@@ -31,12 +32,10 @@ enum class LayerType
Dequantize,
DetectionPostProcess,
Division,
- Equal,
FakeQuantization,
Floor,
FullyConnected,
Gather,
- Greater,
Input,
InstanceNormalization,
L2Normalization,
diff --git a/src/armnn/LayerSupport.cpp b/src/armnn/LayerSupport.cpp
index f88e4e1cc9..997b5f245a 100644
--- a/src/armnn/LayerSupport.cpp
+++ b/src/armnn/LayerSupport.cpp
@@ -262,7 +262,12 @@ bool IsEqualSupported(const BackendId& backend,
char* reasonIfUnsupported,
size_t reasonIfUnsupportedMaxLength)
{
- FORWARD_LAYER_SUPPORT_FUNC(backend, IsEqualSupported, input0, input1, output);
+ FORWARD_LAYER_SUPPORT_FUNC(backend,
+ IsComparisonSupported,
+ input0,
+ input1,
+ output,
+ ComparisonDescriptor(ComparisonOperation::Equal));
}
bool IsFakeQuantizationSupported(const BackendId& backend,
@@ -317,7 +322,12 @@ bool IsGreaterSupported(const BackendId& backend,
char* reasonIfUnsupported,
size_t reasonIfUnsupportedMaxLength)
{
- FORWARD_LAYER_SUPPORT_FUNC(backend, IsGreaterSupported, input0, input1, output);
+ FORWARD_LAYER_SUPPORT_FUNC(backend,
+ IsComparisonSupported,
+ input0,
+ input1,
+ output,
+ ComparisonDescriptor(ComparisonOperation::Greater));
}
bool IsInputSupported(const BackendId& backend,
diff --git a/src/armnn/LayersFwd.hpp b/src/armnn/LayersFwd.hpp
index 7bb9c64818..6c30749904 100644
--- a/src/armnn/LayersFwd.hpp
+++ b/src/armnn/LayersFwd.hpp
@@ -12,6 +12,7 @@
#include "layers/ArgMinMaxLayer.hpp"
#include "layers/BatchNormalizationLayer.hpp"
#include "layers/BatchToSpaceNdLayer.hpp"
+#include "layers/ComparisonLayer.hpp"
#include "layers/ConcatLayer.hpp"
#include "layers/ConstantLayer.hpp"
#include "layers/ConvertFp16ToFp32Layer.hpp"
@@ -23,12 +24,10 @@
#include "layers/DequantizeLayer.hpp"
#include "layers/DetectionPostProcessLayer.hpp"
#include "layers/DivisionLayer.hpp"
-#include "layers/EqualLayer.hpp"
#include "layers/FakeQuantizationLayer.hpp"
#include "layers/FloorLayer.hpp"
#include "layers/FullyConnectedLayer.hpp"
#include "layers/GatherLayer.hpp"
-#include "layers/GreaterLayer.hpp"
#include "layers/InputLayer.hpp"
#include "layers/InstanceNormalizationLayer.hpp"
#include "layers/L2NormalizationLayer.hpp"
@@ -97,6 +96,7 @@ DECLARE_LAYER(Addition)
DECLARE_LAYER(ArgMinMax)
DECLARE_LAYER(BatchNormalization)
DECLARE_LAYER(BatchToSpaceNd)
+DECLARE_LAYER(Comparison)
DECLARE_LAYER(Concat)
DECLARE_LAYER(Constant)
DECLARE_LAYER(ConvertFp16ToFp32)
@@ -108,12 +108,10 @@ DECLARE_LAYER(DepthwiseConvolution2d)
DECLARE_LAYER(Dequantize)
DECLARE_LAYER(DetectionPostProcess)
DECLARE_LAYER(Division)
-DECLARE_LAYER(Equal)
DECLARE_LAYER(FakeQuantization)
DECLARE_LAYER(Floor)
DECLARE_LAYER(FullyConnected)
DECLARE_LAYER(Gather)
-DECLARE_LAYER(Greater)
DECLARE_LAYER(Input)
DECLARE_LAYER(InstanceNormalization)
DECLARE_LAYER(L2Normalization)
diff --git a/src/armnn/Network.cpp b/src/armnn/Network.cpp
index b2fc1a6389..857f6b3959 100644
--- a/src/armnn/Network.cpp
+++ b/src/armnn/Network.cpp
@@ -938,6 +938,12 @@ IConnectableLayer* Network::AddBatchToSpaceNdLayer(const BatchToSpaceNdDescripto
return m_Graph->AddLayer<BatchToSpaceNdLayer>(batchToSpaceNdDescriptor, name);
}
+IConnectableLayer* Network::AddComparisonLayer(const ComparisonDescriptor& comparisonDescriptor,
+ const char* name)
+{
+ return m_Graph->AddLayer<ComparisonLayer>(comparisonDescriptor, name);
+}
+
IConnectableLayer* Network::AddFullyConnectedLayerImpl(const FullyConnectedDescriptor& fullyConnectedDescriptor,
const ConstTensor& weights,
const Optional<ConstTensor>& biases,
@@ -1436,12 +1442,12 @@ IConnectableLayer* Network::AddStridedSliceLayer(const StridedSliceDescriptor& s
IConnectableLayer* Network::AddGreaterLayer(const char* name)
{
- return m_Graph->AddLayer<GreaterLayer>(name);
+ return AddComparisonLayer(ComparisonDescriptor(ComparisonOperation::Greater), name);
}
IConnectableLayer* Network::AddEqualLayer(const char* name)
{
- return m_Graph->AddLayer<EqualLayer>(name);
+ return AddComparisonLayer(ComparisonDescriptor(ComparisonOperation::Equal), name);
}
IConnectableLayer* Network::AddRsqrtLayer(const char * name)
diff --git a/src/armnn/Network.hpp b/src/armnn/Network.hpp
index ad1e7c456e..c1d99a9010 100644
--- a/src/armnn/Network.hpp
+++ b/src/armnn/Network.hpp
@@ -42,6 +42,9 @@ public:
IConnectableLayer* AddBatchToSpaceNdLayer(const BatchToSpaceNdDescriptor& batchToSpaceNdDescriptor,
const char* name = nullptr) override;
+ IConnectableLayer* AddComparisonLayer(const ComparisonDescriptor& comparisonDescriptor,
+ const char* name = nullptr) override;
+
IConnectableLayer* AddConcatLayer(const ConcatDescriptor& concatDescriptor,
const char* name = nullptr) override;
@@ -197,8 +200,10 @@ public:
IConnectableLayer* AddMinimumLayer(const char* name = nullptr) override;
+ ARMNN_DEPRECATED_MSG("Use AddComparisonLayer instead")
IConnectableLayer* AddGreaterLayer(const char* name = nullptr) override;
+ ARMNN_DEPRECATED_MSG("Use AddComparisonLayer instead")
IConnectableLayer* AddEqualLayer(const char* name = nullptr) override;
IConnectableLayer* AddRsqrtLayer(const char* name = nullptr) override;
diff --git a/src/armnn/layers/ComparisonLayer.cpp b/src/armnn/layers/ComparisonLayer.cpp
new file mode 100644
index 0000000000..75518e580e
--- /dev/null
+++ b/src/armnn/layers/ComparisonLayer.cpp
@@ -0,0 +1,80 @@
+//
+// Copyright © 2019 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include "ComparisonLayer.hpp"
+
+#include "LayerCloneBase.hpp"
+
+#include <backendsCommon/WorkloadData.hpp>
+#include <backendsCommon/WorkloadFactory.hpp>
+
+#include <algorithm>
+
+namespace armnn
+{
+
+ComparisonLayer::ComparisonLayer(const ComparisonDescriptor& param, const char* name)
+ : LayerWithParameters(2, 1, LayerType::Comparison, param, name)
+{
+}
+
+std::unique_ptr<IWorkload> ComparisonLayer::CreateWorkload(const Graph& graph,
+ const IWorkloadFactory& factory) const
+{
+ ComparisonQueueDescriptor descriptor;
+ return factory.CreateComparison(descriptor, PrepInfoAndDesc(descriptor, graph));
+}
+
+ComparisonLayer* ComparisonLayer::Clone(Graph& graph) const
+{
+ return CloneBase<ComparisonLayer>(graph, m_Param, GetName());
+}
+
+std::vector<TensorShape> ComparisonLayer::InferOutputShapes(const std::vector<TensorShape>& inputShapes) const
+{
+ BOOST_ASSERT(inputShapes.size() == 2);
+ const TensorShape& input0 = inputShapes[0];
+ const TensorShape& input1 = inputShapes[1];
+
+ BOOST_ASSERT(input0.GetNumDimensions() == input1.GetNumDimensions());
+ unsigned int numDims = input0.GetNumDimensions();
+
+ std::vector<unsigned int> dims(numDims);
+ for (unsigned int i = 0; i < numDims; i++)
+ {
+ unsigned int dim0 = input0[i];
+ unsigned int dim1 = input1[i];
+
+ BOOST_ASSERT_MSG(dim0 == dim1 || dim0 == 1 || dim1 == 1,
+ "Dimensions should either match or one should be of size 1.");
+
+ dims[i] = std::max(dim0, dim1);
+ }
+
+ return std::vector<TensorShape>({ TensorShape(numDims, dims.data()) });
+}
+
+void ComparisonLayer::ValidateTensorShapesFromInputs()
+{
+ VerifyLayerConnections(2, CHECK_LOCATION());
+
+ std::vector<TensorShape> inferredShapes = InferOutputShapes({
+ GetInputSlot(0).GetConnection()->GetTensorInfo().GetShape(),
+ GetInputSlot(1).GetConnection()->GetTensorInfo().GetShape()
+ });
+ BOOST_ASSERT(inferredShapes.size() == 1);
+
+ ConditionalThrowIfNotEqual<LayerValidationException>(
+ "ComparisonLayer: TensorShape set on OutputSlot[0] does not match the inferred shape.",
+ GetOutputSlot(0).GetTensorInfo().GetShape(),
+ inferredShapes[0]);
+}
+
+void ComparisonLayer::Accept(ILayerVisitor& visitor) const
+{
+ visitor.VisitComparisonLayer(this, GetParameters(), GetName());
+}
+
+} // namespace armnn
diff --git a/src/armnn/layers/ComparisonLayer.hpp b/src/armnn/layers/ComparisonLayer.hpp
new file mode 100644
index 0000000000..bbc2b573bf
--- /dev/null
+++ b/src/armnn/layers/ComparisonLayer.hpp
@@ -0,0 +1,50 @@
+//
+// Copyright © 2019 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+#include "LayerWithParameters.hpp"
+
+namespace armnn
+{
+
+/// This layer represents a comparison operation.
+class ComparisonLayer : public LayerWithParameters<ComparisonDescriptor>
+{
+public:
+ /// Makes a workload for the Comparison type
+ /// @param [in] graph The graph where this layer can be found
+ /// @param [in] factory The workload factory which will create the workload
+ /// @return A pointer to the created workload, or nullptr if not created
+ virtual std::unique_ptr<IWorkload> CreateWorkload(const Graph& graph,
+ const IWorkloadFactory& factory) const override;
+
+ /// Creates a dynamically-allocated copy of this layer
+ /// @param [in] graph The graph into which this layer is being cloned
+ ComparisonLayer* Clone(Graph& graph) const override;
+
+ /// By default returns inputShapes if the number of inputs are equal to number of outputs,
+ /// otherwise infers the output shapes from given input shapes and layer properties.
+ /// @param [in] inputShapes The input shapes layer has.
+ /// @return A vector to the inferred output shape.
+ std::vector<TensorShape> InferOutputShapes(const std::vector<TensorShape>& inputShapes) const override;
+
+ /// Check if the input tensor shape(s) will lead to a valid configuration
+ /// of @ref ComparisonLayer
+ void ValidateTensorShapesFromInputs() override;
+
+ void Accept(ILayerVisitor& visitor) const override;
+
+protected:
+ /// Constructor to create a ComparisonLayer
+ /// @param [in] param ComparisonDescriptor to configure the ComparisonLayer
+ /// @param [in] name Optional name for the layer
+ ComparisonLayer(const ComparisonDescriptor& param, const char* name);
+
+ /// Default destructor
+ ~ComparisonLayer() = default;
+};
+
+} // namespace armnn
diff --git a/src/armnn/layers/EqualLayer.cpp b/src/armnn/layers/EqualLayer.cpp
deleted file mode 100644
index 7d16668b06..0000000000
--- a/src/armnn/layers/EqualLayer.cpp
+++ /dev/null
@@ -1,39 +0,0 @@
-//
-// Copyright © 2017 Arm Ltd. All rights reserved.
-// SPDX-License-Identifier: MIT
-//
-
-#include "EqualLayer.hpp"
-
-#include "LayerCloneBase.hpp"
-
-#include <armnn/TypesUtils.hpp>
-#include <backendsCommon/WorkloadData.hpp>
-#include <backendsCommon/WorkloadFactory.hpp>
-
-namespace armnn
-{
-
-EqualLayer::EqualLayer(const char* name)
- : ElementwiseBaseLayer(2, 1, LayerType::Equal, name)
-{
-}
-
-std::unique_ptr<IWorkload> EqualLayer::CreateWorkload(const Graph& graph,
- const IWorkloadFactory& factory) const
-{
- EqualQueueDescriptor descriptor;
- return factory.CreateEqual(descriptor, PrepInfoAndDesc(descriptor, graph));
-}
-
-EqualLayer* EqualLayer::Clone(Graph& graph) const
-{
- return CloneBase<EqualLayer>(graph, GetName());
-}
-
-void EqualLayer::Accept(ILayerVisitor& visitor) const
-{
- visitor.VisitEqualLayer(this, GetName());
-}
-
-} // namespace armnn
diff --git a/src/armnn/layers/EqualLayer.hpp b/src/armnn/layers/EqualLayer.hpp
deleted file mode 100644
index b6a01eff2d..0000000000
--- a/src/armnn/layers/EqualLayer.hpp
+++ /dev/null
@@ -1,38 +0,0 @@
-//
-// Copyright © 2017 Arm Ltd. All rights reserved.
-// SPDX-License-Identifier: MIT
-//
-
-#pragma once
-
-#include "ElementwiseBaseLayer.hpp"
-
-namespace armnn
-{
-/// This layer represents an equal operation.
-class EqualLayer : public ElementwiseBaseLayer
-{
-public:
- /// Makes a workload for the Equal type.
- /// @param [in] graph The graph where this layer can be found.
- /// @param [in] factory The workload factory which will create the workload.
- /// @return A pointer to the created workload, or nullptr if not created.
- virtual std::unique_ptr<IWorkload> CreateWorkload(const Graph& graph,
- const IWorkloadFactory& factory) const override;
-
- /// Creates a dynamically-allocated copy of this layer.
- /// @param [in] graph The graph into which this layer is being cloned.
- EqualLayer* Clone(Graph& graph) const override;
-
- void Accept(ILayerVisitor& visitor) const override;
-
-protected:
- /// Constructor to create a EqualLayer.
- /// @param [in] name Optional name for the layer.
- EqualLayer(const char* name);
-
- /// Default destructor
- ~EqualLayer() = default;
-};
-
-} //namespace armnn
diff --git a/src/armnn/layers/GreaterLayer.cpp b/src/armnn/layers/GreaterLayer.cpp
deleted file mode 100644
index a9fe5e0d8c..0000000000
--- a/src/armnn/layers/GreaterLayer.cpp
+++ /dev/null
@@ -1,39 +0,0 @@
-//
-// Copyright © 2017 Arm Ltd. All rights reserved.
-// SPDX-License-Identifier: MIT
-//
-
-#include "GreaterLayer.hpp"
-
-#include "LayerCloneBase.hpp"
-
-#include <armnn/TypesUtils.hpp>
-#include <backendsCommon/WorkloadData.hpp>
-#include <backendsCommon/WorkloadFactory.hpp>
-
-namespace armnn
-{
-
-GreaterLayer::GreaterLayer(const char* name)
- : ElementwiseBaseLayer(2, 1, LayerType::Greater, name)
-{
-}
-
-std::unique_ptr<IWorkload> GreaterLayer::CreateWorkload(const Graph& graph,
- const IWorkloadFactory& factory) const
-{
- GreaterQueueDescriptor descriptor;
- return factory.CreateGreater(descriptor, PrepInfoAndDesc(descriptor, graph));
-}
-
-GreaterLayer* GreaterLayer::Clone(Graph& graph) const
-{
- return CloneBase<GreaterLayer>(graph, GetName());
-}
-
-void GreaterLayer::Accept(ILayerVisitor& visitor) const
-{
- visitor.VisitGreaterLayer(this, GetName());
-}
-
-} // namespace armnn
diff --git a/src/armnn/layers/GreaterLayer.hpp b/src/armnn/layers/GreaterLayer.hpp
deleted file mode 100644
index bdee948d6e..0000000000
--- a/src/armnn/layers/GreaterLayer.hpp
+++ /dev/null
@@ -1,39 +0,0 @@
-//
-// Copyright © 2017 Arm Ltd. All rights reserved.
-// SPDX-License-Identifier: MIT
-//
-
-#pragma once
-
-#include "ElementwiseBaseLayer.hpp"
-
-namespace armnn
-{
-
-/// This layer represents a greater operation.
-class GreaterLayer : public ElementwiseBaseLayer
-{
-public:
- /// Makes a workload for the Greater type.
- /// @param [in] graph The graph where this layer can be found.
- /// @param [in] factory The workload factory which will create the workload.
- /// @return A pointer to the created workload, or nullptr if not created.
- virtual std::unique_ptr<IWorkload> CreateWorkload(const Graph& graph,
- const IWorkloadFactory& factory) const override;
-
- /// Creates a dynamically-allocated copy of this layer.
- /// @param [in] graph The graph into which this layer is being cloned.
- GreaterLayer* Clone(Graph& graph) const override;
-
- void Accept(ILayerVisitor& visitor) const override;
-
-protected:
- /// Constructor to create a GreaterLayer.
- /// @param [in] name Optional name for the layer.
- GreaterLayer(const char* name);
-
- /// Default destructor
- ~GreaterLayer() = default;
-};
-
-} //namespace armnn
diff --git a/src/armnn/test/TestNameAndDescriptorLayerVisitor.cpp b/src/armnn/test/TestNameAndDescriptorLayerVisitor.cpp
index 0b126235e8..36bbd36792 100644
--- a/src/armnn/test/TestNameAndDescriptorLayerVisitor.cpp
+++ b/src/armnn/test/TestNameAndDescriptorLayerVisitor.cpp
@@ -65,6 +65,12 @@ armnn::BatchToSpaceNdDescriptor GetDescriptor<armnn::BatchToSpaceNdDescriptor>()
}
template<>
+armnn::ComparisonDescriptor GetDescriptor<armnn::ComparisonDescriptor>()
+{
+ return armnn::ComparisonDescriptor(armnn::ComparisonOperation::GreaterOrEqual);
+}
+
+template<>
armnn::ConcatDescriptor GetDescriptor<armnn::ConcatDescriptor>()
{
armnn::ConcatDescriptor descriptor(2, 2);
@@ -243,6 +249,7 @@ TEST_SUITE_NAME_AND_DESCRIPTOR_LAYER_VISITOR(Activation)
TEST_SUITE_NAME_AND_DESCRIPTOR_LAYER_VISITOR(ArgMinMax)
TEST_SUITE_NAME_AND_DESCRIPTOR_LAYER_VISITOR(DepthToSpace)
TEST_SUITE_NAME_AND_DESCRIPTOR_LAYER_VISITOR(BatchToSpaceNd)
+TEST_SUITE_NAME_AND_DESCRIPTOR_LAYER_VISITOR(Comparison)
TEST_SUITE_NAME_AND_DESCRIPTOR_LAYER_VISITOR(Concat)
TEST_SUITE_NAME_AND_DESCRIPTOR_LAYER_VISITOR(InstanceNormalization)
TEST_SUITE_NAME_AND_DESCRIPTOR_LAYER_VISITOR(L2Normalization)
@@ -262,4 +269,4 @@ TEST_SUITE_NAME_AND_DESCRIPTOR_LAYER_VISITOR(Splitter)
TEST_SUITE_NAME_AND_DESCRIPTOR_LAYER_VISITOR(Stack)
TEST_SUITE_NAME_AND_DESCRIPTOR_LAYER_VISITOR(StridedSlice)
-BOOST_AUTO_TEST_SUITE_END()
\ No newline at end of file
+BOOST_AUTO_TEST_SUITE_END()
diff --git a/src/armnn/test/TestNameAndDescriptorLayerVisitor.hpp b/src/armnn/test/TestNameAndDescriptorLayerVisitor.hpp
index aefcba5f59..b1f7f57075 100644
--- a/src/armnn/test/TestNameAndDescriptorLayerVisitor.hpp
+++ b/src/armnn/test/TestNameAndDescriptorLayerVisitor.hpp
@@ -46,6 +46,7 @@ public: \
DECLARE_TEST_NAME_AND_DESCRIPTOR_LAYER_VISITOR_CLASS(Activation)
DECLARE_TEST_NAME_AND_DESCRIPTOR_LAYER_VISITOR_CLASS(ArgMinMax)
DECLARE_TEST_NAME_AND_DESCRIPTOR_LAYER_VISITOR_CLASS(BatchToSpaceNd)
+DECLARE_TEST_NAME_AND_DESCRIPTOR_LAYER_VISITOR_CLASS(Comparison)
DECLARE_TEST_NAME_AND_DESCRIPTOR_LAYER_VISITOR_CLASS(Concat)
DECLARE_TEST_NAME_AND_DESCRIPTOR_LAYER_VISITOR_CLASS(DepthToSpace)
DECLARE_TEST_NAME_AND_DESCRIPTOR_LAYER_VISITOR_CLASS(InstanceNormalization)
diff --git a/src/armnn/test/TestNameOnlyLayerVisitor.cpp b/src/armnn/test/TestNameOnlyLayerVisitor.cpp
index 6bc2dc7c65..32de94e7ef 100644
--- a/src/armnn/test/TestNameOnlyLayerVisitor.cpp
+++ b/src/armnn/test/TestNameOnlyLayerVisitor.cpp
@@ -42,10 +42,8 @@ TEST_SUITE_NAME_ONLY_LAYER_VISITOR(Abs)
TEST_SUITE_NAME_ONLY_LAYER_VISITOR(Addition)
TEST_SUITE_NAME_ONLY_LAYER_VISITOR(Dequantize)
TEST_SUITE_NAME_ONLY_LAYER_VISITOR(Division)
-TEST_SUITE_NAME_ONLY_LAYER_VISITOR(Equal)
TEST_SUITE_NAME_ONLY_LAYER_VISITOR(Floor)
TEST_SUITE_NAME_ONLY_LAYER_VISITOR(Gather)
-TEST_SUITE_NAME_ONLY_LAYER_VISITOR(Greater)
TEST_SUITE_NAME_ONLY_LAYER_VISITOR(Maximum)
TEST_SUITE_NAME_ONLY_LAYER_VISITOR(Merge)
TEST_SUITE_NAME_ONLY_LAYER_VISITOR(Minimum)
diff --git a/src/armnn/test/TestNameOnlyLayerVisitor.hpp b/src/armnn/test/TestNameOnlyLayerVisitor.hpp
index a772cb3283..c770b5e9e0 100644
--- a/src/armnn/test/TestNameOnlyLayerVisitor.hpp
+++ b/src/armnn/test/TestNameOnlyLayerVisitor.hpp
@@ -29,10 +29,8 @@ DECLARE_TEST_NAME_ONLY_LAYER_VISITOR_CLASS(Abs)
DECLARE_TEST_NAME_ONLY_LAYER_VISITOR_CLASS(Addition)
DECLARE_TEST_NAME_ONLY_LAYER_VISITOR_CLASS(Dequantize)
DECLARE_TEST_NAME_ONLY_LAYER_VISITOR_CLASS(Division)
-DECLARE_TEST_NAME_ONLY_LAYER_VISITOR_CLASS(Equal)
DECLARE_TEST_NAME_ONLY_LAYER_VISITOR_CLASS(Floor)
DECLARE_TEST_NAME_ONLY_LAYER_VISITOR_CLASS(Gather)
-DECLARE_TEST_NAME_ONLY_LAYER_VISITOR_CLASS(Greater)
DECLARE_TEST_NAME_ONLY_LAYER_VISITOR_CLASS(Maximum)
DECLARE_TEST_NAME_ONLY_LAYER_VISITOR_CLASS(Merge)
DECLARE_TEST_NAME_ONLY_LAYER_VISITOR_CLASS(Minimum)
diff --git a/src/armnnDeserializer/Deserializer.cpp b/src/armnnDeserializer/Deserializer.cpp
index 67836c5843..0b18ccd051 100644
--- a/src/armnnDeserializer/Deserializer.cpp
+++ b/src/armnnDeserializer/Deserializer.cpp
@@ -1270,7 +1270,8 @@ void Deserializer::ParseEqual(GraphPtr graph, unsigned int layerIndex)
CHECK_VALID_SIZE(outputs.size(), 1);
auto layerName = GetLayerName(graph, layerIndex);
- IConnectableLayer* layer = m_Network->AddEqualLayer(layerName.c_str());
+ armnn::ComparisonDescriptor descriptor(armnn::ComparisonOperation::Equal);
+ IConnectableLayer* layer = m_Network->AddComparisonLayer(descriptor, layerName.c_str());
armnn::TensorInfo outputTensorInfo = ToTensorInfo(outputs[0]);
layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
@@ -1290,7 +1291,8 @@ void Deserializer::ParseGreater(GraphPtr graph, unsigned int layerIndex)
CHECK_VALID_SIZE(outputs.size(), 1);
auto layerName = GetLayerName(graph, layerIndex);
- IConnectableLayer* layer = m_Network->AddGreaterLayer(layerName.c_str());
+ armnn::ComparisonDescriptor descriptor(armnn::ComparisonOperation::Greater);
+ IConnectableLayer* layer = m_Network->AddComparisonLayer(descriptor, layerName.c_str());
armnn::TensorInfo outputTensorInfo = ToTensorInfo(outputs[0]);
layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
diff --git a/src/armnnSerializer/Serializer.cpp b/src/armnnSerializer/Serializer.cpp
index 11f833c5b7..5c9855f87e 100644
--- a/src/armnnSerializer/Serializer.cpp
+++ b/src/armnnSerializer/Serializer.cpp
@@ -240,6 +240,13 @@ void SerializerVisitor::VisitBatchNormalizationLayer(const armnn::IConnectableLa
CreateAnyLayer(fbBatchNormalizationLayer.o, serializer::Layer::Layer_BatchNormalizationLayer);
}
+void SerializerVisitor::VisitComparisonLayer(const armnn::IConnectableLayer* layer,
+ const armnn::ComparisonDescriptor& descriptor,
+ const char* name)
+{
+ throw armnn::UnimplementedException("SerializerVisitor::VisitComparisonLayer() is not implemented");
+}
+
// Build FlatBuffer for Constant Layer
void SerializerVisitor::VisitConstantLayer(const armnn::IConnectableLayer* layer,
const armnn::ConstTensor& input,
diff --git a/src/armnnSerializer/Serializer.hpp b/src/armnnSerializer/Serializer.hpp
index 8c13245aeb..79dc17ba01 100644
--- a/src/armnnSerializer/Serializer.hpp
+++ b/src/armnnSerializer/Serializer.hpp
@@ -68,6 +68,10 @@ public:
const armnn::ConstTensor& gamma,
const char* name = nullptr) override;
+ void VisitComparisonLayer(const armnn::IConnectableLayer* layer,
+ const armnn::ComparisonDescriptor& descriptor,
+ const char* name = nullptr) override;
+
void VisitConcatLayer(const armnn::IConnectableLayer* layer,
const armnn::ConcatDescriptor& concatDescriptor,
const char* name = nullptr) override;
@@ -103,6 +107,7 @@ public:
void VisitDivisionLayer(const armnn::IConnectableLayer* layer,
const char* name = nullptr) override;
+ ARMNN_DEPRECATED_MSG("Use VisitComparisonLayer instead")
void VisitEqualLayer(const armnn::IConnectableLayer* layer,
const char* name = nullptr) override;
@@ -118,6 +123,7 @@ public:
void VisitGatherLayer(const armnn::IConnectableLayer* layer,
const char* name = nullptr) override;
+ ARMNN_DEPRECATED_MSG("Use VisitComparisonLayer instead")
void VisitGreaterLayer(const armnn::IConnectableLayer* layer,
const char* name = nullptr) override;
diff --git a/src/armnnSerializer/test/SerializerTests.cpp b/src/armnnSerializer/test/SerializerTests.cpp
index a70c891849..58f56f484f 100644
--- a/src/armnnSerializer/test/SerializerTests.cpp
+++ b/src/armnnSerializer/test/SerializerTests.cpp
@@ -1009,48 +1009,6 @@ BOOST_AUTO_TEST_CASE(SerializeDivision)
deserializedNetwork->Accept(verifier);
}
-BOOST_AUTO_TEST_CASE(SerializeEqual)
-{
- class EqualLayerVerifier : public LayerVerifierBase
- {
- public:
- EqualLayerVerifier(const std::string& layerName,
- const std::vector<armnn::TensorInfo>& inputInfos,
- const std::vector<armnn::TensorInfo>& outputInfos)
- : LayerVerifierBase(layerName, inputInfos, outputInfos) {}
-
- void VisitEqualLayer(const armnn::IConnectableLayer* layer, const char* name) override
- {
- VerifyNameAndConnections(layer, name);
- }
- };
-
- const std::string layerName("equal");
- const armnn::TensorInfo inputTensorInfo1 = armnn::TensorInfo({2, 1, 2, 4}, armnn::DataType::Float32);
- const armnn::TensorInfo inputTensorInfo2 = armnn::TensorInfo({2, 1, 2, 4}, armnn::DataType::Float32);
- const armnn::TensorInfo outputTensorInfo = armnn::TensorInfo({2, 1, 2, 4}, armnn::DataType::Boolean);
-
- armnn::INetworkPtr network = armnn::INetwork::Create();
- armnn::IConnectableLayer* const inputLayer1 = network->AddInputLayer(0);
- armnn::IConnectableLayer* const inputLayer2 = network->AddInputLayer(1);
- armnn::IConnectableLayer* const equalLayer = network->AddEqualLayer(layerName.c_str());
- armnn::IConnectableLayer* const outputLayer = network->AddOutputLayer(0);
-
- inputLayer1->GetOutputSlot(0).Connect(equalLayer->GetInputSlot(0));
- inputLayer2->GetOutputSlot(0).Connect(equalLayer->GetInputSlot(1));
- equalLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
-
- inputLayer1->GetOutputSlot(0).SetTensorInfo(inputTensorInfo1);
- inputLayer2->GetOutputSlot(0).SetTensorInfo(inputTensorInfo2);
- equalLayer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
-
- armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network));
- BOOST_CHECK(deserializedNetwork);
-
- EqualLayerVerifier verifier(layerName, {inputTensorInfo1, inputTensorInfo2}, {outputTensorInfo});
- deserializedNetwork->Accept(verifier);
-}
-
BOOST_AUTO_TEST_CASE(SerializeFloor)
{
class FloorLayerVerifier : public LayerVerifierBase
@@ -1225,48 +1183,6 @@ BOOST_AUTO_TEST_CASE(SerializeGather)
deserializedNetwork->Accept(verifier);
}
-BOOST_AUTO_TEST_CASE(SerializeGreater)
-{
- class GreaterLayerVerifier : public LayerVerifierBase
- {
- public:
- GreaterLayerVerifier(const std::string& layerName,
- const std::vector<armnn::TensorInfo>& inputInfos,
- const std::vector<armnn::TensorInfo>& outputInfos)
- : LayerVerifierBase(layerName, inputInfos, outputInfos) {}
-
- void VisitGreaterLayer(const armnn::IConnectableLayer* layer, const char* name) override
- {
- VerifyNameAndConnections(layer, name);
- }
- };
-
- const std::string layerName("greater");
- const armnn::TensorInfo inputTensorInfo1({ 1, 2, 2, 2 }, armnn::DataType::Float32);
- const armnn::TensorInfo inputTensorInfo2({ 1, 2, 2, 2 }, armnn::DataType::Float32);
- const armnn::TensorInfo outputTensorInfo({ 1, 2, 2, 2 }, armnn::DataType::Boolean);
-
- armnn::INetworkPtr network = armnn::INetwork::Create();
- armnn::IConnectableLayer* const inputLayer1 = network->AddInputLayer(0);
- armnn::IConnectableLayer* const inputLayer2 = network->AddInputLayer(1);
- armnn::IConnectableLayer* const greaterLayer = network->AddGreaterLayer(layerName.c_str());
- armnn::IConnectableLayer* const outputLayer = network->AddOutputLayer(0);
-
- inputLayer1->GetOutputSlot(0).Connect(greaterLayer->GetInputSlot(0));
- inputLayer2->GetOutputSlot(0).Connect(greaterLayer->GetInputSlot(1));
- greaterLayer->GetOutputSlot(0).Connect(outputLayer->GetInputSlot(0));
-
- inputLayer1->GetOutputSlot(0).SetTensorInfo(inputTensorInfo1);
- inputLayer2->GetOutputSlot(0).SetTensorInfo(inputTensorInfo2);
- greaterLayer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo);
-
- armnn::INetworkPtr deserializedNetwork = DeserializeNetwork(SerializeNetwork(*network));
- BOOST_CHECK(deserializedNetwork);
-
- GreaterLayerVerifier verifier(layerName, {inputTensorInfo1, inputTensorInfo2}, {outputTensorInfo});
- deserializedNetwork->Accept(verifier);
-}
-
BOOST_AUTO_TEST_CASE(SerializeInstanceNormalization)
{
class InstanceNormalizationLayerVerifier : public LayerVerifierBase
diff --git a/src/armnnTfParser/TfParser.cpp b/src/armnnTfParser/TfParser.cpp
index 76d25d1d05..d085ed84e3 100755
--- a/src/armnnTfParser/TfParser.cpp
+++ b/src/armnnTfParser/TfParser.cpp
@@ -1862,7 +1862,8 @@ ParsedTfOperationPtr TfParser::ParseGreater(const tensorflow::NodeDef& nodeDef,
IOutputSlot* input0Slot = inputLayers.first;
IOutputSlot* input1Slot = inputLayers.second;
- IConnectableLayer* const layer = m_Network->AddGreaterLayer(nodeDef.name().c_str());
+ ComparisonDescriptor descriptor(ComparisonOperation::Greater);
+ IConnectableLayer* const layer = m_Network->AddComparisonLayer(descriptor, nodeDef.name().c_str());
return ProcessComparisonLayer(input0Slot, input1Slot, layer, nodeDef);
}
@@ -1874,7 +1875,8 @@ ParsedTfOperationPtr TfParser::ParseEqual(const tensorflow::NodeDef& nodeDef,
IOutputSlot* input0Slot = inputLayers.first;
IOutputSlot* input1Slot = inputLayers.second;
- IConnectableLayer* const layer = m_Network->AddEqualLayer(nodeDef.name().c_str());
+ ComparisonDescriptor descriptor(ComparisonOperation::Equal);
+ IConnectableLayer* const layer = m_Network->AddComparisonLayer(descriptor, nodeDef.name().c_str());
return ProcessComparisonLayer(input0Slot, input1Slot, layer, nodeDef);
}
diff --git a/src/backends/backendsCommon/LayerSupportBase.cpp b/src/backends/backendsCommon/LayerSupportBase.cpp
index 7d5555ce68..358106e5e9 100644
--- a/src/backends/backendsCommon/LayerSupportBase.cpp
+++ b/src/backends/backendsCommon/LayerSupportBase.cpp
@@ -84,6 +84,15 @@ bool LayerSupportBase::IsBatchToSpaceNdSupported(const TensorInfo& input,
return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported);
}
+bool LayerSupportBase::IsComparisonSupported(const TensorInfo& input0,
+ const TensorInfo& input1,
+ const TensorInfo& output,
+ const ComparisonDescriptor& descriptor,
+ Optional<std::string&> reasonIfUnsupported) const
+{
+ return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported);
+}
+
bool LayerSupportBase::IsConcatSupported(const std::vector<const TensorInfo*> inputs,
const TensorInfo& output,
const OriginsDescriptor& descriptor,
diff --git a/src/backends/backendsCommon/LayerSupportBase.hpp b/src/backends/backendsCommon/LayerSupportBase.hpp
index cb660f5c2b..d4c37c1a91 100644
--- a/src/backends/backendsCommon/LayerSupportBase.hpp
+++ b/src/backends/backendsCommon/LayerSupportBase.hpp
@@ -46,6 +46,12 @@ public:
const BatchToSpaceNdDescriptor& descriptor,
Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
+ bool IsComparisonSupported(const TensorInfo& input0,
+ const TensorInfo& input1,
+ const TensorInfo& output,
+ const ComparisonDescriptor& descriptor,
+ Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
+
bool IsConcatSupported(const std::vector<const TensorInfo*> inputs,
const TensorInfo& output,
const OriginsDescriptor& descriptor,
@@ -108,6 +114,7 @@ public:
const TensorInfo& output,
Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
+ ARMNN_DEPRECATED_MSG("Use IsComparisonSupported instead")
bool IsEqualSupported(const TensorInfo& input0,
const TensorInfo& input1,
const TensorInfo& output,
@@ -133,6 +140,7 @@ public:
const TensorInfo& output,
Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
+ ARMNN_DEPRECATED_MSG("Use IsComparisonSupported instead")
bool IsGreaterSupported(const TensorInfo& input0,
const TensorInfo& input1,
const TensorInfo& output,
diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp
index b8d4f0dfff..cfb38b4820 100644
--- a/src/backends/backendsCommon/WorkloadData.cpp
+++ b/src/backends/backendsCommon/WorkloadData.cpp
@@ -2779,4 +2779,28 @@ void DepthToSpaceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) con
}
}
+void ComparisonQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
+{
+ const std::string descriptorName{"ComparisonQueueDescriptor"};
+
+ ValidateNumInputs(workloadInfo, descriptorName, 2);
+ ValidateNumOutputs(workloadInfo, descriptorName, 1);
+
+ const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0];
+ const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1];
+ const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0];
+
+ ValidateBroadcastTensorShapesMatch(inputTensorInfo0,
+ inputTensorInfo1,
+ outputTensorInfo,
+ descriptorName,
+ "input_0",
+ "input_1");
+
+ if (outputTensorInfo.GetDataType() != DataType::Boolean)
+ {
+ throw InvalidArgumentException(descriptorName + ": Output tensor type must be Boolean.");
+ }
+}
+
} // namespace armnn
diff --git a/src/backends/backendsCommon/WorkloadData.hpp b/src/backends/backendsCommon/WorkloadData.hpp
index 5a3600fc71..b45a1718f6 100644
--- a/src/backends/backendsCommon/WorkloadData.hpp
+++ b/src/backends/backendsCommon/WorkloadData.hpp
@@ -553,4 +553,9 @@ struct DepthToSpaceQueueDescriptor : QueueDescriptorWithParameters<DepthToSpaceD
void Validate(const WorkloadInfo& workloadInfo) const;
};
+struct ComparisonQueueDescriptor : QueueDescriptorWithParameters<ComparisonDescriptor>
+{
+ void Validate(const WorkloadInfo& workloadInfo) const;
+};
+
} // namespace armnn
diff --git a/src/backends/backendsCommon/WorkloadFactory.cpp b/src/backends/backendsCommon/WorkloadFactory.cpp
index f19b48491a..30dfa023f9 100644
--- a/src/backends/backendsCommon/WorkloadFactory.cpp
+++ b/src/backends/backendsCommon/WorkloadFactory.cpp
@@ -148,6 +148,21 @@ bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId,
reason);
break;
}
+ case LayerType::Comparison:
+ {
+ auto cLayer = boost::polymorphic_downcast<const ComparisonLayer*>(&layer);
+
+ const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
+ const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
+ const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
+
+ result = layerSupportObject->IsComparisonSupported(OverrideDataType(input0, dataType),
+ OverrideDataType(input1, dataType),
+ OverrideDataType(output, DataType::Boolean),
+ cLayer->GetParameters(),
+ reason);
+ break;
+ }
case LayerType::Constant:
{
const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
@@ -268,17 +283,6 @@ bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId,
reason);
break;
}
- case LayerType::Equal:
- {
- const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
- const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
- const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
- result = layerSupportObject->IsEqualSupported(OverrideDataType(input0, dataType),
- OverrideDataType(input1, dataType),
- OverrideDataType(output, dataType),
- reason);
- break;
- }
case LayerType::FakeQuantization:
{
auto cLayer = boost::polymorphic_downcast<const FakeQuantizationLayer*>(&layer);
@@ -957,17 +961,6 @@ bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId,
reason);
break;
}
- case LayerType::Greater:
- {
- const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
- const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
- const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
- result = layerSupportObject->IsGreaterSupported(OverrideDataType(input0, dataType),
- OverrideDataType(input1, dataType),
- OverrideDataType(output, DataType::Boolean),
- reason);
- break;
- }
case LayerType::Prelu:
{
const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
@@ -1065,6 +1058,12 @@ std::unique_ptr<IWorkload> IWorkloadFactory::CreateBatchToSpaceNd(const BatchToS
return std::unique_ptr<IWorkload>();
}
+std::unique_ptr<IWorkload> IWorkloadFactory::CreateComparison(const ComparisonQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const
+{
+ return std::unique_ptr<IWorkload>();
+}
+
std::unique_ptr<IWorkload> IWorkloadFactory::CreateConcat(const ConcatQueueDescriptor& descriptor,
const WorkloadInfo& info) const
{
diff --git a/src/backends/backendsCommon/WorkloadFactory.hpp b/src/backends/backendsCommon/WorkloadFactory.hpp
index fa7a9d46a8..819b8c768b 100644
--- a/src/backends/backendsCommon/WorkloadFactory.hpp
+++ b/src/backends/backendsCommon/WorkloadFactory.hpp
@@ -69,6 +69,9 @@ public:
virtual std::unique_ptr<IWorkload> CreateBatchToSpaceNd(const BatchToSpaceNdQueueDescriptor& descriptor,
const WorkloadInfo& Info) const;
+ virtual std::unique_ptr<IWorkload> CreateComparison(const ComparisonQueueDescriptor& descriptor,
+ const WorkloadInfo& Info) const;
+
virtual std::unique_ptr<IWorkload> CreateConcat(const ConcatQueueDescriptor& descriptor,
const WorkloadInfo& info) const;
@@ -102,6 +105,7 @@ public:
virtual std::unique_ptr<IWorkload> CreateDivision(const DivisionQueueDescriptor& descriptor,
const WorkloadInfo& info) const;
+ ARMNN_DEPRECATED_MSG("Use CreateComparison instead")
virtual std::unique_ptr<IWorkload> CreateEqual(const EqualQueueDescriptor& descriptor,
const WorkloadInfo& Info) const;
@@ -117,6 +121,7 @@ public:
virtual std::unique_ptr<IWorkload> CreateGather(const GatherQueueDescriptor& descriptor,
const WorkloadInfo& info) const;
+ ARMNN_DEPRECATED_MSG("Use CreateComparison instead")
virtual std::unique_ptr<IWorkload> CreateGreater(const GreaterQueueDescriptor& descriptor,
const WorkloadInfo& info) const;
diff --git a/src/backends/backendsCommon/test/ArithmeticTestImpl.hpp b/src/backends/backendsCommon/test/ArithmeticTestImpl.hpp
deleted file mode 100644
index d0e85dd31d..0000000000
--- a/src/backends/backendsCommon/test/ArithmeticTestImpl.hpp
+++ /dev/null
@@ -1,113 +0,0 @@
-//
-// Copyright © 2017 Arm Ltd. All rights reserved.
-// SPDX-License-Identifier: MIT
-//
-#pragma once
-
-#include "CommonTestUtils.hpp"
-
-#include <ResolveType.hpp>
-
-#include <armnn/INetwork.hpp>
-
-#include <boost/test/unit_test.hpp>
-
-#include <vector>
-
-namespace
-{
-
-template<armnn::DataType ArmnnTypeInput, armnn::DataType ArmnnTypeOutput>
-INetworkPtr CreateArithmeticNetwork(const std::vector<TensorShape>& inputShapes,
- const TensorShape& outputShape,
- const LayerType type,
- const float qScale = 1.0f,
- const int32_t qOffset = 0)
-{
- using namespace armnn;
-
- // Builds up the structure of the network.
- INetworkPtr net(INetwork::Create());
-
- IConnectableLayer* arithmeticLayer = nullptr;
-
- switch(type){
- case LayerType::Equal: arithmeticLayer = net->AddEqualLayer("equal"); break;
- case LayerType::Greater: arithmeticLayer = net->AddGreaterLayer("greater"); break;
- default: BOOST_TEST_FAIL("Non-Arithmetic layer type called.");
- }
-
- for (unsigned int i = 0; i < inputShapes.size(); ++i)
- {
- TensorInfo inputTensorInfo(inputShapes[i], ArmnnTypeInput, qScale, qOffset);
- IConnectableLayer* input = net->AddInputLayer(boost::numeric_cast<LayerBindingId>(i));
- Connect(input, arithmeticLayer, inputTensorInfo, 0, i);
- }
-
- TensorInfo outputTensorInfo(outputShape, ArmnnTypeOutput, qScale, qOffset);
- IConnectableLayer* output = net->AddOutputLayer(0, "output");
- Connect(arithmeticLayer, output, outputTensorInfo, 0, 0);
-
- return net;
-}
-
-template<armnn::DataType ArmnnInputType,
- armnn::DataType ArmnnOutputType,
- typename TInput = armnn::ResolveType<ArmnnInputType>,
- typename TOutput = armnn::ResolveType<ArmnnOutputType>>
-void ArithmeticSimpleEndToEnd(const std::vector<BackendId>& backends,
- const LayerType type,
- const std::vector<TOutput> expectedOutput)
-{
- using namespace armnn;
-
- const std::vector<TensorShape> inputShapes{{ 2, 2, 2, 2 }, { 2, 2, 2, 2 }};
- const TensorShape& outputShape = { 2, 2, 2, 2 };
-
- // Builds up the structure of the network
- INetworkPtr net = CreateArithmeticNetwork<ArmnnInputType, ArmnnOutputType>(inputShapes, outputShape, type);
-
- BOOST_TEST_CHECKPOINT("create a network");
-
- const std::vector<TInput> input0({ 1, 1, 1, 1, 5, 5, 5, 5,
- 3, 3, 3, 3, 4, 4, 4, 4 });
-
- const std::vector<TInput> input1({ 1, 1, 1, 1, 3, 3, 3, 3,
- 5, 5, 5, 5, 4, 4, 4, 4 });
-
- std::map<int, std::vector<TInput>> inputTensorData = {{ 0, input0 }, { 1, input1 }};
- std::map<int, std::vector<TOutput>> expectedOutputData = {{ 0, expectedOutput }};
-
- EndToEndLayerTestImpl<ArmnnInputType, ArmnnOutputType>(move(net), inputTensorData, expectedOutputData, backends);
-}
-
-template<armnn::DataType ArmnnInputType,
- armnn::DataType ArmnnOutputType,
- typename TInput = armnn::ResolveType<ArmnnInputType>,
- typename TOutput = armnn::ResolveType<ArmnnOutputType>>
-void ArithmeticBroadcastEndToEnd(const std::vector<BackendId>& backends,
- const LayerType type,
- const std::vector<TOutput> expectedOutput)
-{
- using namespace armnn;
-
- const std::vector<TensorShape> inputShapes{{ 1, 2, 2, 3 }, { 1, 1, 1, 3 }};
- const TensorShape& outputShape = { 1, 2, 2, 3 };
-
- // Builds up the structure of the network
- INetworkPtr net = CreateArithmeticNetwork<ArmnnInputType, ArmnnOutputType>(inputShapes, outputShape, type);
-
- BOOST_TEST_CHECKPOINT("create a network");
-
- const std::vector<TInput> input0({ 1, 2, 3, 1, 0, 6,
- 7, 8, 9, 10, 11, 12 });
-
- const std::vector<TInput> input1({ 1, 1, 3 });
-
- std::map<int, std::vector<TInput>> inputTensorData = {{ 0, input0 }, { 1, input1 }};
- std::map<int, std::vector<TOutput>> expectedOutputData = {{ 0, expectedOutput }};
-
- EndToEndLayerTestImpl<ArmnnInputType, ArmnnOutputType>(move(net), inputTensorData, expectedOutputData, backends);
-}
-
-} // anonymous namespace
diff --git a/src/backends/backendsCommon/test/CMakeLists.txt b/src/backends/backendsCommon/test/CMakeLists.txt
index 797dc90952..7449e69dcd 100644
--- a/src/backends/backendsCommon/test/CMakeLists.txt
+++ b/src/backends/backendsCommon/test/CMakeLists.txt
@@ -10,6 +10,7 @@ list(APPEND armnnBackendsCommonUnitTests_sources
BackendRegistryTests.cpp
CommonTestUtils.cpp
CommonTestUtils.hpp
+ ComparisonEndToEndTestImpl.hpp
DataLayoutUtils.hpp
DataTypeUtils.hpp
DepthToSpaceEndToEndTestImpl.hpp
@@ -58,6 +59,7 @@ list(APPEND armnnBackendsCommonUnitTests_sources
layerTests/BatchNormalizationTestImpl.cpp
layerTests/BatchNormalizationTestImpl.hpp
layerTests/BatchToSpaceNdTestImpl.hpp
+ layerTests/ComparisonTestImpl.hpp
layerTests/ConcatTestImpl.cpp
layerTests/ConcatTestImpl.hpp
layerTests/ConstantTestImpl.cpp
diff --git a/src/backends/backendsCommon/test/ComparisonEndToEndTestImpl.hpp b/src/backends/backendsCommon/test/ComparisonEndToEndTestImpl.hpp
new file mode 100644
index 0000000000..dc53b7b246
--- /dev/null
+++ b/src/backends/backendsCommon/test/ComparisonEndToEndTestImpl.hpp
@@ -0,0 +1,103 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+#pragma once
+
+#include "CommonTestUtils.hpp"
+
+#include <ResolveType.hpp>
+
+#include <armnn/INetwork.hpp>
+
+#include <boost/test/unit_test.hpp>
+
+#include <vector>
+
+namespace
+{
+
+template<armnn::DataType ArmnnTypeInput>
+INetworkPtr CreateComparisonNetwork(const std::vector<TensorShape>& inputShapes,
+ const TensorShape& outputShape,
+ ComparisonOperation operation,
+ const float qScale = 1.0f,
+ const int32_t qOffset = 0)
+{
+ using namespace armnn;
+
+ INetworkPtr net(INetwork::Create());
+
+ ComparisonDescriptor descriptor(operation);
+ IConnectableLayer* comparisonLayer = net->AddComparisonLayer(descriptor, "comparison");
+
+ for (unsigned int i = 0; i < inputShapes.size(); ++i)
+ {
+ TensorInfo inputTensorInfo(inputShapes[i], ArmnnTypeInput, qScale, qOffset);
+ IConnectableLayer* input = net->AddInputLayer(boost::numeric_cast<LayerBindingId>(i));
+ Connect(input, comparisonLayer, inputTensorInfo, 0, i);
+ }
+
+ TensorInfo outputTensorInfo(outputShape, DataType::Boolean, qScale, qOffset);
+ IConnectableLayer* output = net->AddOutputLayer(0, "output");
+ Connect(comparisonLayer, output, outputTensorInfo, 0, 0);
+
+ return net;
+}
+
+template<armnn::DataType ArmnnInType,
+ typename TInput = armnn::ResolveType<ArmnnInType>>
+void ComparisonSimpleEndToEnd(const std::vector<BackendId>& backends,
+ ComparisonOperation operation,
+ const std::vector<uint8_t> expectedOutput)
+{
+ using namespace armnn;
+
+ const std::vector<TensorShape> inputShapes{{ 2, 2, 2, 2 }, { 2, 2, 2, 2 }};
+ const TensorShape& outputShape = { 2, 2, 2, 2 };
+
+ // Builds up the structure of the network
+ INetworkPtr net = CreateComparisonNetwork<ArmnnInType>(inputShapes, outputShape, operation);
+
+ BOOST_TEST_CHECKPOINT("create a network");
+
+ const std::vector<TInput> input0({ 1, 1, 1, 1, 5, 5, 5, 5,
+ 3, 3, 3, 3, 4, 4, 4, 4 });
+
+ const std::vector<TInput> input1({ 1, 1, 1, 1, 3, 3, 3, 3,
+ 5, 5, 5, 5, 4, 4, 4, 4 });
+
+ std::map<int, std::vector<TInput>> inputTensorData = {{ 0, input0 }, { 1, input1 }};
+ std::map<int, std::vector<uint8_t>> expectedOutputData = {{ 0, expectedOutput }};
+
+ EndToEndLayerTestImpl<ArmnnInType, DataType::Boolean>(move(net), inputTensorData, expectedOutputData, backends);
+}
+
+template<armnn::DataType ArmnnInType,
+ typename TInput = armnn::ResolveType<ArmnnInType>>
+void ComparisonBroadcastEndToEnd(const std::vector<BackendId>& backends,
+ ComparisonOperation operation,
+ const std::vector<uint8_t> expectedOutput)
+{
+ using namespace armnn;
+
+ const std::vector<TensorShape> inputShapes{{ 1, 2, 2, 3 }, { 1, 1, 1, 3 }};
+ const TensorShape& outputShape = { 1, 2, 2, 3 };
+
+ // Builds up the structure of the network
+ INetworkPtr net = CreateComparisonNetwork<ArmnnInType>(inputShapes, outputShape, operation);
+
+ BOOST_TEST_CHECKPOINT("create a network");
+
+ const std::vector<TInput> input0({ 1, 2, 3, 1, 0, 6,
+ 7, 8, 9, 10, 11, 12 });
+
+ const std::vector<TInput> input1({ 1, 1, 3 });
+
+ std::map<int, std::vector<TInput>> inputTensorData = {{ 0, input0 }, { 1, input1 }};
+ std::map<int, std::vector<uint8_t>> expectedOutputData = {{ 0, expectedOutput }};
+
+ EndToEndLayerTestImpl<ArmnnInType, DataType::Boolean>(move(net), inputTensorData, expectedOutputData, backends);
+}
+
+} // anonymous namespace
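
The helpers above drive the new frontend through CreateComparisonNetwork. For orientation, the sketch below builds an equivalent two-input Greater network directly against the INetwork API added by this patch; the shapes, layer names and the surrounding function are illustrative assumptions, not part of the change.

// Illustrative sketch only: a two-input Greater network built with the new
// ComparisonLayer frontend. Shapes and names are arbitrary examples.
#include <armnn/ArmNN.hpp>

armnn::INetworkPtr BuildGreaterNetwork()
{
    using namespace armnn;

    INetworkPtr net = INetwork::Create();

    // One descriptor per comparison layer; the operation is its only parameter.
    ComparisonDescriptor descriptor(ComparisonOperation::Greater);

    IConnectableLayer* input0     = net->AddInputLayer(0, "input0");
    IConnectableLayer* input1     = net->AddInputLayer(1, "input1");
    IConnectableLayer* comparison = net->AddComparisonLayer(descriptor, "greater");
    IConnectableLayer* output     = net->AddOutputLayer(0, "output");

    TensorInfo inputInfo({ 1, 2, 2, 3 }, DataType::Float32);
    TensorInfo outputInfo({ 1, 2, 2, 3 }, DataType::Boolean); // comparison always yields Boolean

    input0->GetOutputSlot(0).SetTensorInfo(inputInfo);
    input1->GetOutputSlot(0).SetTensorInfo(inputInfo);
    comparison->GetOutputSlot(0).SetTensorInfo(outputInfo);

    input0->GetOutputSlot(0).Connect(comparison->GetInputSlot(0));
    input1->GetOutputSlot(0).Connect(comparison->GetInputSlot(1));
    comparison->GetOutputSlot(0).Connect(output->GetInputSlot(0));

    return net;
}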
diff --git a/src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp b/src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp
index 907285c5cf..9bddae9759 100644
--- a/src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp
+++ b/src/backends/backendsCommon/test/IsLayerSupportedTestImpl.hpp
@@ -397,6 +397,8 @@ DECLARE_LAYER_POLICY_2_PARAM(BatchNormalization)
DECLARE_LAYER_POLICY_2_PARAM(BatchToSpaceNd)
+DECLARE_LAYER_POLICY_2_PARAM(Comparison)
+
DECLARE_LAYER_POLICY_2_PARAM(Concat)
DECLARE_LAYER_POLICY_1_PARAM(Constant)
@@ -421,8 +423,6 @@ DECLARE_LAYER_POLICY_1_PARAM(Dequantize)
DECLARE_LAYER_POLICY_2_PARAM(DetectionPostProcess)
-DECLARE_LAYER_POLICY_1_PARAM(Equal)
-
DECLARE_LAYER_POLICY_2_PARAM(FakeQuantization)
DECLARE_LAYER_POLICY_1_PARAM(Floor)
@@ -431,8 +431,6 @@ DECLARE_LAYER_POLICY_2_PARAM(FullyConnected)
DECLARE_LAYER_POLICY_1_PARAM(Gather)
-DECLARE_LAYER_POLICY_1_PARAM(Greater)
-
DECLARE_LAYER_POLICY_CUSTOM_PARAM(Input, armnn::LayerBindingId)
DECLARE_LAYER_POLICY_2_PARAM(InstanceNormalization)
diff --git a/src/backends/backendsCommon/test/layerTests/ComparisonTestImpl.hpp b/src/backends/backendsCommon/test/layerTests/ComparisonTestImpl.hpp
new file mode 100644
index 0000000000..6ce9b306c0
--- /dev/null
+++ b/src/backends/backendsCommon/test/layerTests/ComparisonTestImpl.hpp
@@ -0,0 +1,126 @@
+//
+// Copyright © 2019 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+#include "LayerTestResult.hpp"
+
+#include <armnn/ArmNN.hpp>
+
+#include <ResolveType.hpp>
+
+#include <backendsCommon/IBackendInternal.hpp>
+#include <backendsCommon/Workload.hpp>
+#include <backendsCommon/WorkloadData.hpp>
+#include <backendsCommon/WorkloadFactory.hpp>
+
+#include <backendsCommon/test/TensorCopyUtils.hpp>
+#include <backendsCommon/test/WorkloadTestUtils.hpp>
+
+#include <test/TensorHelpers.hpp>
+
+template <std::size_t NumDims,
+ armnn::DataType ArmnnInType,
+ typename InType = armnn::ResolveType<ArmnnInType>>
+LayerTestResult<uint8_t, NumDims> ComparisonTestImpl(
+ armnn::IWorkloadFactory & workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr & memoryManager,
+ const armnn::ComparisonDescriptor& descriptor,
+ const unsigned int shape0[NumDims],
+ std::vector<InType> values0,
+ float quantScale0,
+ int quantOffset0,
+ const unsigned int shape1[NumDims],
+ std::vector<InType> values1,
+ float quantScale1,
+ int quantOffset1,
+ const unsigned int outShape[NumDims],
+ std::vector<uint8_t> outValues,
+ float outQuantScale,
+ int outQuantOffset)
+{
+ armnn::TensorInfo inputTensorInfo0{NumDims, shape0, ArmnnInType};
+ armnn::TensorInfo inputTensorInfo1{NumDims, shape1, ArmnnInType};
+ armnn::TensorInfo outputTensorInfo{NumDims, outShape, armnn::DataType::Boolean};
+
+ auto input0 = MakeTensor<InType, NumDims>(inputTensorInfo0, values0);
+ auto input1 = MakeTensor<InType, NumDims>(inputTensorInfo1, values1);
+
+ inputTensorInfo0.SetQuantizationScale(quantScale0);
+ inputTensorInfo0.SetQuantizationOffset(quantOffset0);
+
+ inputTensorInfo1.SetQuantizationScale(quantScale1);
+ inputTensorInfo1.SetQuantizationOffset(quantOffset1);
+
+ outputTensorInfo.SetQuantizationScale(outQuantScale);
+ outputTensorInfo.SetQuantizationOffset(outQuantOffset);
+
+ LayerTestResult<uint8_t, NumDims> ret(outputTensorInfo);
+
+ std::unique_ptr<armnn::ITensorHandle> inputHandle0 = workloadFactory.CreateTensorHandle(inputTensorInfo0);
+ std::unique_ptr<armnn::ITensorHandle> inputHandle1 = workloadFactory.CreateTensorHandle(inputTensorInfo1);
+ std::unique_ptr<armnn::ITensorHandle> outputHandle = workloadFactory.CreateTensorHandle(outputTensorInfo);
+
+ armnn::ComparisonQueueDescriptor qDescriptor;
+ qDescriptor.m_Parameters = descriptor;
+
+ armnn::WorkloadInfo info;
+ AddInputToWorkload(qDescriptor, info, inputTensorInfo0, inputHandle0.get());
+ AddInputToWorkload(qDescriptor, info, inputTensorInfo1, inputHandle1.get());
+ AddOutputToWorkload(qDescriptor, info, outputTensorInfo, outputHandle.get());
+
+ std::unique_ptr<armnn::IWorkload> workload = workloadFactory.CreateComparison(qDescriptor, info);
+
+ inputHandle0->Allocate();
+ inputHandle1->Allocate();
+ outputHandle->Allocate();
+
+ CopyDataToITensorHandle(inputHandle0.get(), input0.origin());
+ CopyDataToITensorHandle(inputHandle1.get(), input1.origin());
+
+ workload->PostAllocationConfigure();
+ ExecuteWorkload(*workload, memoryManager);
+
+ CopyDataFromITensorHandle(ret.output.origin(), outputHandle.get());
+
+ ret.outputExpected = MakeTensor<uint8_t, NumDims>(outputTensorInfo, outValues);
+ ret.compareBoolean = true;
+
+ return ret;
+}
+
+template <std::size_t NumDims,
+ armnn::DataType ArmnnInType,
+ typename InType = armnn::ResolveType<ArmnnInType>>
+LayerTestResult<uint8_t, NumDims> ComparisonTestImpl(
+ armnn::IWorkloadFactory & workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr & memoryManager,
+ const armnn::ComparisonDescriptor& descriptor,
+ const unsigned int shape0[NumDims],
+ std::vector<InType> values0,
+ const unsigned int shape1[NumDims],
+ std::vector<InType> values1,
+ const unsigned int outShape[NumDims],
+ std::vector<uint8_t> outValues,
+ float qScale = 10.f,
+ int qOffset = 0)
+{
+ return ComparisonTestImpl<NumDims, ArmnnInType>(
+ workloadFactory,
+ memoryManager,
+ descriptor,
+ shape0,
+ values0,
+ qScale,
+ qOffset,
+ shape1,
+ values1,
+ qScale,
+ qOffset,
+ outShape,
+ outValues,
+ qScale,
+ qOffset);
+}
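
The second ComparisonTestImpl overload forwards one shared quantization scale and offset for both inputs and the output, so per-operation layer tests only differ in the descriptor and the data. A hypothetical test for another operation could reuse it as sketched below; the function name, shapes and values are made-up examples, not part of this patch.

// Illustrative sketch only: reusing ComparisonTestImpl for a NotEqual case.
// Assumes it lives next to the other layer tests and includes "ComparisonTestImpl.hpp".
LayerTestResult<uint8_t, 4> NotEqualSimpleExample(
    armnn::IWorkloadFactory& workloadFactory,
    const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
{
    const unsigned int shape[] = { 1, 1, 2, 2 };

    std::vector<float> input0 = { 1.f, 2.f, 3.f, 4.f };
    std::vector<float> input1 = { 1.f, 5.f, 3.f, 0.f };

    // Expected Boolean output stored as uint8_t: 1 where the values differ.
    std::vector<uint8_t> output = { 0, 1, 0, 1 };

    return ComparisonTestImpl<4, armnn::DataType::Float32>(
        workloadFactory,
        memoryManager,
        armnn::ComparisonDescriptor(armnn::ComparisonOperation::NotEqual),
        shape, input0,
        shape, input1,
        shape, output);
}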
diff --git a/src/backends/backendsCommon/test/layerTests/EqualTestImpl.cpp b/src/backends/backendsCommon/test/layerTests/EqualTestImpl.cpp
index b0b613c137..a3d2b2796f 100644
--- a/src/backends/backendsCommon/test/layerTests/EqualTestImpl.cpp
+++ b/src/backends/backendsCommon/test/layerTests/EqualTestImpl.cpp
@@ -4,18 +4,10 @@
//
#include "EqualTestImpl.hpp"
-#include "ElementwiseTestImpl.hpp"
-#include <Half.hpp>
+#include "ComparisonTestImpl.hpp"
-template<>
-std::unique_ptr<armnn::IWorkload> CreateWorkload<armnn::EqualQueueDescriptor>(
- const armnn::IWorkloadFactory& workloadFactory,
- const armnn::WorkloadInfo& info,
- const armnn::EqualQueueDescriptor& descriptor)
-{
- return workloadFactory.CreateEqual(descriptor, info);
-}
+#include <Half.hpp>
LayerTestResult<uint8_t, 4> EqualSimpleTest(armnn::IWorkloadFactory& workloadFactory,
const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
@@ -39,9 +31,10 @@ LayerTestResult<uint8_t, 4> EqualSimpleTest(armnn::IWorkloadFactory& workloadFac
std::vector<uint8_t> output({ 1, 1, 1, 1, 0, 0, 0, 0,
0, 0, 0, 0, 1, 1, 1, 1 });
- return ElementwiseTestHelper<4, armnn::EqualQueueDescriptor, armnn::DataType::Float32, armnn::DataType::Boolean>(
+ return ComparisonTestImpl<4, armnn::DataType::Float32>(
workloadFactory,
memoryManager,
+ armnn::ComparisonDescriptor(armnn::ComparisonOperation::Equal),
shape,
input0,
shape,
@@ -62,9 +55,10 @@ LayerTestResult<uint8_t, 4> EqualBroadcast1ElementTest(
std::vector<uint8_t> output({ 1, 0, 0, 0, 0, 0, 0, 0});
- return ElementwiseTestHelper<4, armnn::EqualQueueDescriptor, armnn::DataType::Float32, armnn::DataType::Boolean>(
+ return ComparisonTestImpl<4, armnn::DataType::Float32>(
workloadFactory,
memoryManager,
+ armnn::ComparisonDescriptor(armnn::ComparisonOperation::Equal),
shape0,
input0,
shape1,
@@ -88,9 +82,10 @@ LayerTestResult<uint8_t, 4> EqualBroadcast1DVectorTest(
std::vector<uint8_t> output({ 1, 1, 1, 0, 0, 0,
0, 0, 0, 0, 0, 0 });
- return ElementwiseTestHelper<4, armnn::EqualQueueDescriptor, armnn::DataType::Float32, armnn::DataType::Boolean>(
+ return ComparisonTestImpl<4, armnn::DataType::Float32>(
workloadFactory,
memoryManager,
+ armnn::ComparisonDescriptor(armnn::ComparisonOperation::Equal),
shape0,
input0,
shape1,
@@ -117,12 +112,10 @@ LayerTestResult<uint8_t, 4> EqualFloat16Test(
std::vector<uint8_t> output({ 0, 0, 0, 0, 1, 1, 1, 1,
1, 1, 1, 1, 0, 0, 0, 0 });
- return ElementwiseTestHelper<4,
- armnn::EqualQueueDescriptor,
- armnn::DataType::Float16,
- armnn::DataType::Boolean>(
+ return ComparisonTestImpl<4, armnn::DataType::Float16>(
workloadFactory,
memoryManager,
+ armnn::ComparisonDescriptor(armnn::ComparisonOperation::Equal),
shape,
input0,
shape,
@@ -148,12 +141,10 @@ LayerTestResult<uint8_t, 4> EqualBroadcast1ElementFloat16Test(
std::vector<uint8_t> output({ 1, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0 });
- return ElementwiseTestHelper<4,
- armnn::EqualQueueDescriptor,
- armnn::DataType::Float16,
- armnn::DataType::Boolean>(
+ return ComparisonTestImpl<4, armnn::DataType::Float16>(
workloadFactory,
memoryManager,
+ armnn::ComparisonDescriptor(armnn::ComparisonOperation::Equal),
shape0,
input0,
shape1,
@@ -179,12 +170,10 @@ LayerTestResult<uint8_t, 4> EqualBroadcast1DVectorFloat16Test(
std::vector<uint8_t> output({ 1, 0, 1, 0, 0, 0,
0, 0, 0, 0, 0, 0 });
- return ElementwiseTestHelper<4,
- armnn::EqualQueueDescriptor,
- armnn::DataType::Float16,
- armnn::DataType::Boolean>(
+ return ComparisonTestImpl<4, armnn::DataType::Float16>(
workloadFactory,
memoryManager,
+ armnn::ComparisonDescriptor(armnn::ComparisonOperation::Equal),
shape0,
input0,
shape1,
@@ -209,12 +198,10 @@ LayerTestResult<uint8_t, 4> EqualUint8Test(
std::vector<uint8_t> output({ 0, 0, 0, 0, 1, 1, 1, 1,
1, 1, 1, 1, 0, 0, 0, 0 });
- return ElementwiseTestHelper<4,
- armnn::EqualQueueDescriptor,
- armnn::DataType::QuantisedAsymm8,
- armnn::DataType::Boolean>(
+ return ComparisonTestImpl<4, armnn::DataType::QuantisedAsymm8>(
workloadFactory,
memoryManager,
+ armnn::ComparisonDescriptor(armnn::ComparisonOperation::Equal),
shape,
input0,
shape,
@@ -238,12 +225,10 @@ LayerTestResult<uint8_t, 4> EqualBroadcast1ElementUint8Test(
std::vector<uint8_t> output({ 1, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0 });
- return ElementwiseTestHelper<4,
- armnn::EqualQueueDescriptor,
- armnn::DataType::QuantisedAsymm8,
- armnn::DataType::Boolean>(
+ return ComparisonTestImpl<4, armnn::DataType::QuantisedAsymm8>(
workloadFactory,
memoryManager,
+ armnn::ComparisonDescriptor(armnn::ComparisonOperation::Equal),
shape0,
input0,
shape1,
@@ -267,12 +252,10 @@ LayerTestResult<uint8_t, 4> EqualBroadcast1DVectorUint8Test(
std::vector<uint8_t> output({ 1, 0, 1, 0, 0, 0,
0, 0, 0, 0, 0, 0 });
- return ElementwiseTestHelper<4,
- armnn::EqualQueueDescriptor,
- armnn::DataType::QuantisedAsymm8,
- armnn::DataType::Boolean>(
+ return ComparisonTestImpl<4, armnn::DataType::QuantisedAsymm8>(
workloadFactory,
memoryManager,
+ armnn::ComparisonDescriptor(armnn::ComparisonOperation::Equal),
shape0,
input0,
shape1,
diff --git a/src/backends/backendsCommon/test/layerTests/GreaterTestImpl.cpp b/src/backends/backendsCommon/test/layerTests/GreaterTestImpl.cpp
index 0148216285..271bc235a9 100644
--- a/src/backends/backendsCommon/test/layerTests/GreaterTestImpl.cpp
+++ b/src/backends/backendsCommon/test/layerTests/GreaterTestImpl.cpp
@@ -4,18 +4,10 @@
//
#include "GreaterTestImpl.hpp"
-#include "ElementwiseTestImpl.hpp"
-#include <Half.hpp>
+#include "ComparisonTestImpl.hpp"
-template<>
-std::unique_ptr<armnn::IWorkload> CreateWorkload<armnn::GreaterQueueDescriptor>(
- const armnn::IWorkloadFactory& workloadFactory,
- const armnn::WorkloadInfo& info,
- const armnn::GreaterQueueDescriptor& descriptor)
-{
- return workloadFactory.CreateGreater(descriptor, info);
-}
+#include <Half.hpp>
LayerTestResult<uint8_t, 4> GreaterSimpleTest(armnn::IWorkloadFactory& workloadFactory,
const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
@@ -45,12 +37,10 @@ LayerTestResult<uint8_t, 4> GreaterSimpleTest(armnn::IWorkloadFactory& workloadF
0, 0, 0, 0, 0, 0, 0, 0
};
- return ElementwiseTestHelper<4,
- armnn::GreaterQueueDescriptor,
- armnn::DataType::Float32,
- armnn::DataType::Boolean>(
+ return ComparisonTestImpl<4, armnn::DataType::Float32>(
workloadFactory,
memoryManager,
+ armnn::ComparisonDescriptor(armnn::ComparisonOperation::Greater),
shape,
input0,
shape,
@@ -71,12 +61,10 @@ LayerTestResult<uint8_t, 4> GreaterBroadcast1ElementTest(
std::vector<uint8_t> output = { 0, 1, 1, 1, 1, 1, 1, 1};
- return ElementwiseTestHelper<4,
- armnn::GreaterQueueDescriptor,
- armnn::DataType::Float32,
- armnn::DataType::Boolean>(
+ return ComparisonTestImpl<4, armnn::DataType::Float32>(
workloadFactory,
memoryManager,
+ armnn::ComparisonDescriptor(armnn::ComparisonOperation::Greater),
shape0,
input0,
shape1,
@@ -106,12 +94,10 @@ LayerTestResult<uint8_t, 4> GreaterBroadcast1DVectorTest(
1, 1, 1, 1, 1, 1
};
- return ElementwiseTestHelper<4,
- armnn::GreaterQueueDescriptor,
- armnn::DataType::Float32,
- armnn::DataType::Boolean>(
+ return ComparisonTestImpl<4, armnn::DataType::Float32>(
workloadFactory,
memoryManager,
+ armnn::ComparisonDescriptor(armnn::ComparisonOperation::Greater),
shape0,
input0,
shape1,
@@ -151,12 +137,10 @@ LayerTestResult<uint8_t, 4> GreaterFloat16Test(
0, 0, 0, 0, 0, 0, 0, 0
};
- return ElementwiseTestHelper<4,
- armnn::GreaterQueueDescriptor,
- armnn::DataType::Float16,
- armnn::DataType::Boolean>(
+    return ComparisonTestImpl<4, armnn::DataType::Float16>(
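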
workloadFactory,
memoryManager,
+ armnn::ComparisonDescriptor(armnn::ComparisonOperation::Greater),
shape,
input0,
shape,
@@ -179,12 +163,10 @@ LayerTestResult<uint8_t, 4> GreaterBroadcast1ElementFloat16Test(
std::vector<uint8_t> output = { 0, 1, 1, 1, 1, 1, 1, 1};
- return ElementwiseTestHelper<4,
- armnn::GreaterQueueDescriptor,
- armnn::DataType::Float16,
- armnn::DataType::Boolean>(
+ return ComparisonTestImpl<4, armnn::DataType::Float16>(
workloadFactory,
memoryManager,
+ armnn::ComparisonDescriptor(armnn::ComparisonOperation::Greater),
shape0,
input0,
shape1,
@@ -198,7 +180,7 @@ LayerTestResult<uint8_t, 4> GreaterBroadcast1DVectorFloat16Test(
const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
{
using namespace half_float::literal;
-
+
const unsigned int shape0[] = { 1, 2, 2, 3 };
const unsigned int shape1[] = { 1, 1, 1, 3 };
@@ -216,12 +198,10 @@ LayerTestResult<uint8_t, 4> GreaterBroadcast1DVectorFloat16Test(
1, 1, 1, 1, 1, 1
};
- return ElementwiseTestHelper<4,
- armnn::GreaterQueueDescriptor,
- armnn::DataType::Float16,
- armnn::DataType::Boolean>(
+ return ComparisonTestImpl<4, armnn::DataType::Float16>(
workloadFactory,
memoryManager,
+ armnn::ComparisonDescriptor(armnn::ComparisonOperation::Greater),
shape0,
input0,
shape1,
@@ -255,12 +235,10 @@ LayerTestResult<uint8_t, 4> GreaterUint8Test(
1, 1, 1, 1, 0, 0, 0, 0
};
- return ElementwiseTestHelper<4,
- armnn::GreaterQueueDescriptor,
- armnn::DataType::QuantisedAsymm8,
- armnn::DataType::Boolean>(
+ return ComparisonTestImpl<4, armnn::DataType::QuantisedAsymm8>(
workloadFactory,
memoryManager,
+ armnn::ComparisonDescriptor(armnn::ComparisonOperation::Greater),
shape,
input0,
shape,
@@ -290,12 +268,10 @@ LayerTestResult<uint8_t, 4> GreaterBroadcast1ElementUint8Test(
1, 1, 1, 1, 1, 1
};
- return ElementwiseTestHelper<4,
- armnn::GreaterQueueDescriptor,
- armnn::DataType::QuantisedAsymm8,
- armnn::DataType::Boolean>(
+ return ComparisonTestImpl<4, armnn::DataType::QuantisedAsymm8>(
workloadFactory,
memoryManager,
+ armnn::ComparisonDescriptor(armnn::ComparisonOperation::Greater),
shape0,
input0,
shape1,
@@ -325,12 +301,10 @@ LayerTestResult<uint8_t, 4> GreaterBroadcast1DVectorUint8Test(
1, 1, 1, 1, 1, 1
};
- return ElementwiseTestHelper<4,
- armnn::GreaterQueueDescriptor,
- armnn::DataType::QuantisedAsymm8,
- armnn::DataType::Boolean>(
+ return ComparisonTestImpl<4, armnn::DataType::QuantisedAsymm8>(
workloadFactory,
memoryManager,
+ armnn::ComparisonDescriptor(armnn::ComparisonOperation::Greater),
shape0,
input0,
shape1,
diff --git a/src/backends/cl/ClLayerSupport.cpp b/src/backends/cl/ClLayerSupport.cpp
index c5ed8bff2a..bd2be57386 100644
--- a/src/backends/cl/ClLayerSupport.cpp
+++ b/src/backends/cl/ClLayerSupport.cpp
@@ -209,6 +209,24 @@ bool ClLayerSupport::IsBatchToSpaceNdSupported(const TensorInfo& input,
descriptor);
}
+bool ClLayerSupport::IsComparisonSupported(const TensorInfo& input0,
+ const TensorInfo& input1,
+ const TensorInfo& output,
+ const ComparisonDescriptor& descriptor,
+ Optional<std::string&> reasonIfUnsupported) const
+{
+ if (descriptor.m_Operation == ComparisonOperation::Greater)
+ {
+ FORWARD_WORKLOAD_VALIDATE_FUNC(ClGreaterWorkloadValidate,
+ reasonIfUnsupported,
+ input0,
+ input1,
+ output);
+ }
+
+ return false;
+}
+
bool ClLayerSupport::IsConcatSupported(const std::vector<const TensorInfo*> inputs,
const TensorInfo& output,
const ConcatDescriptor& descriptor,
@@ -398,11 +416,8 @@ bool ClLayerSupport::IsGreaterSupported(const TensorInfo& input0,
const TensorInfo& output,
Optional<std::string&> reasonIfUnsupported) const
{
- FORWARD_WORKLOAD_VALIDATE_FUNC(ClGreaterWorkloadValidate,
- reasonIfUnsupported,
- input0,
- input1,
- output);
+ ComparisonDescriptor descriptor(ComparisonOperation::Greater);
+ return IsComparisonSupported(input0, input1, output, descriptor, reasonIfUnsupported);
}
bool ClLayerSupport::IsInputSupported(const TensorInfo& input,
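
In this patch the CL IsComparisonSupported (and the NEON version further down) only forwards ComparisonOperation::Greater to the existing workload validation; every other operation falls through to the final return false. A hedged usage sketch, assuming the internal ClLayerSupport.hpp header is on the include path and with arbitrary tensor shapes:

// Illustrative sketch only: querying the new descriptor-based support check.
#include "ClLayerSupport.hpp"

#include <armnn/Descriptors.hpp>
#include <armnn/Tensor.hpp>
#include <armnn/Types.hpp>

bool ClSupportsEqualExample()
{
    using namespace armnn;

    ClLayerSupport layerSupport;

    const TensorInfo input({ 1, 2, 2, 2 }, DataType::Float32);
    const TensorInfo output({ 1, 2, 2, 2 }, DataType::Boolean);

    // Only Greater is forwarded to ClGreaterWorkloadValidate in this patch;
    // any other operation, including Equal, falls through to 'return false'.
    return layerSupport.IsComparisonSupported(input, input, output,
                                              ComparisonDescriptor(ComparisonOperation::Equal));
}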
diff --git a/src/backends/cl/ClLayerSupport.hpp b/src/backends/cl/ClLayerSupport.hpp
index 59e849316f..26eb42e092 100644
--- a/src/backends/cl/ClLayerSupport.hpp
+++ b/src/backends/cl/ClLayerSupport.hpp
@@ -40,6 +40,12 @@ public:
const BatchToSpaceNdDescriptor& descriptor,
Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
+ bool IsComparisonSupported(const TensorInfo& input0,
+ const TensorInfo& input1,
+                               const TensorInfo& output,
+ const ComparisonDescriptor& descriptor,
+ Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
+
bool IsConcatSupported(const std::vector<const TensorInfo*> inputs,
const TensorInfo& output,
const ConcatDescriptor& descriptor,
@@ -102,6 +108,7 @@ public:
const FullyConnectedDescriptor& descriptor,
Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
+ ARMNN_DEPRECATED_MSG("Use IsComparisonSupported instead")
bool IsGreaterSupported(const TensorInfo& input0,
const TensorInfo& input1,
const TensorInfo& ouput,
diff --git a/src/backends/cl/ClWorkloadFactory.cpp b/src/backends/cl/ClWorkloadFactory.cpp
index c427ae7e12..04e09f4ff1 100644
--- a/src/backends/cl/ClWorkloadFactory.cpp
+++ b/src/backends/cl/ClWorkloadFactory.cpp
@@ -157,6 +157,20 @@ std::unique_ptr<IWorkload> ClWorkloadFactory::CreateBatchToSpaceNd(const BatchTo
return MakeWorkload<ClBatchToSpaceNdWorkload>(descriptor, info);
}
+std::unique_ptr<IWorkload> ClWorkloadFactory::CreateComparison(const ComparisonQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const
+{
+ if (descriptor.m_Parameters.m_Operation == ComparisonOperation::Greater)
+ {
+ GreaterQueueDescriptor greaterQueueDescriptor;
+ greaterQueueDescriptor.m_Inputs = descriptor.m_Inputs;
+ greaterQueueDescriptor.m_Outputs = descriptor.m_Outputs;
+
+ return MakeWorkload<ClGreaterFloat32Workload, ClGreaterUint8Workload>(greaterQueueDescriptor, info);
+ }
+ return MakeWorkload<NullWorkload, NullWorkload>(descriptor, info);
+}
+
std::unique_ptr<IWorkload> ClWorkloadFactory::CreateConcat(const ConcatQueueDescriptor& descriptor,
const WorkloadInfo& info) const
{
@@ -230,7 +244,12 @@ std::unique_ptr<IWorkload> ClWorkloadFactory::CreateDivision(const DivisionQueue
std::unique_ptr<IWorkload> ClWorkloadFactory::CreateEqual(const EqualQueueDescriptor& descriptor,
const WorkloadInfo& info) const
{
- return MakeWorkload<NullWorkload, NullWorkload>(descriptor, info);
+ boost::ignore_unused(descriptor);
+
+ ComparisonQueueDescriptor comparisonDescriptor;
+ comparisonDescriptor.m_Parameters = ComparisonDescriptor(ComparisonOperation::Equal);
+
+ return CreateComparison(comparisonDescriptor, info);
}
std::unique_ptr<IWorkload> ClWorkloadFactory::CreateFakeQuantization(
@@ -261,7 +280,12 @@ std::unique_ptr<IWorkload> ClWorkloadFactory::CreateGather(const GatherQueueDesc
std::unique_ptr<IWorkload> ClWorkloadFactory::CreateGreater(const GreaterQueueDescriptor& descriptor,
const WorkloadInfo& info) const
{
- return MakeWorkload<ClGreaterFloat32Workload, ClGreaterUint8Workload>(descriptor, info);
+ boost::ignore_unused(descriptor);
+
+ ComparisonQueueDescriptor comparisonDescriptor;
+ comparisonDescriptor.m_Parameters = ComparisonDescriptor(ComparisonOperation::Greater);
+
+ return CreateComparison(comparisonDescriptor, info);
}
std::unique_ptr<IWorkload> ClWorkloadFactory::CreateInput(const InputQueueDescriptor& descriptor,
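
CreateComparison is now the single descriptor-based entry point, with the deprecated CreateEqual and CreateGreater reduced to thin forwarders. A minimal sketch of calling it directly follows; populating the queue descriptor's inputs and outputs is left to the caller, as the layer tests earlier in this diff do with AddInputToWorkload and AddOutputToWorkload.

// Illustrative sketch only: creating a Greater workload through the new entry point.
std::unique_ptr<armnn::IWorkload> CreateGreaterWorkloadExample(armnn::ClWorkloadFactory& factory,
                                                               const armnn::WorkloadInfo& info)
{
    armnn::ComparisonQueueDescriptor descriptor;
    descriptor.m_Parameters = armnn::ComparisonDescriptor(armnn::ComparisonOperation::Greater);
    // descriptor.m_Inputs / descriptor.m_Outputs would be populated with
    // ITensorHandle pointers here before the factory call.

    return factory.CreateComparison(descriptor, info);
}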
diff --git a/src/backends/cl/ClWorkloadFactory.hpp b/src/backends/cl/ClWorkloadFactory.hpp
index 9dbc615a4e..1cae6e1faf 100644
--- a/src/backends/cl/ClWorkloadFactory.hpp
+++ b/src/backends/cl/ClWorkloadFactory.hpp
@@ -53,6 +53,9 @@ public:
std::unique_ptr<IWorkload> CreateBatchToSpaceNd(const BatchToSpaceNdQueueDescriptor& descriptor,
const WorkloadInfo& info) const override;
+ std::unique_ptr<IWorkload> CreateComparison(const ComparisonQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const override;
+
std::unique_ptr<IWorkload> CreateConcat(const ConcatQueueDescriptor& descriptor,
const WorkloadInfo& info) const override;
diff --git a/src/backends/cl/test/ClEndToEndTests.cpp b/src/backends/cl/test/ClEndToEndTests.cpp
index 59d26edf22..26f15b77da 100644
--- a/src/backends/cl/test/ClEndToEndTests.cpp
+++ b/src/backends/cl/test/ClEndToEndTests.cpp
@@ -6,7 +6,7 @@
#include <backendsCommon/test/EndToEndTestImpl.hpp>
#include <backendsCommon/test/AbsEndToEndTestImpl.hpp>
-#include <backendsCommon/test/ArithmeticTestImpl.hpp>
+#include <backendsCommon/test/ComparisonEndToEndTestImpl.hpp>
#include <backendsCommon/test/ConcatEndToEndTestImpl.hpp>
#include <backendsCommon/test/DepthToSpaceEndToEndTestImpl.hpp>
#include <backendsCommon/test/DequantizeEndToEndTestImpl.hpp>
@@ -122,9 +122,9 @@ BOOST_AUTO_TEST_CASE(ClGreaterSimpleEndToEndTest)
const std::vector<uint8_t> expectedOutput({ 0, 0, 0, 0, 1, 1, 1, 1,
0, 0, 0, 0, 0, 0, 0, 0 });
- ArithmeticSimpleEndToEnd<armnn::DataType::Float32, armnn::DataType::Boolean>(defaultBackends,
- LayerType::Greater,
- expectedOutput);
+ ComparisonSimpleEndToEnd<armnn::DataType::Float32>(defaultBackends,
+ ComparisonOperation::Greater,
+ expectedOutput);
}
BOOST_AUTO_TEST_CASE(ClGreaterSimpleEndToEndUint8Test)
@@ -132,9 +132,9 @@ BOOST_AUTO_TEST_CASE(ClGreaterSimpleEndToEndUint8Test)
const std::vector<uint8_t> expectedOutput({ 0, 0, 0, 0, 1, 1, 1, 1,
0, 0, 0, 0, 0, 0, 0, 0 });
- ArithmeticSimpleEndToEnd<armnn::DataType::QuantisedAsymm8, armnn::DataType::Boolean>(defaultBackends,
- LayerType::Greater,
- expectedOutput);
+ ComparisonSimpleEndToEnd<armnn::DataType::QuantisedAsymm8>(defaultBackends,
+ ComparisonOperation::Greater,
+ expectedOutput);
}
BOOST_AUTO_TEST_CASE(ClGreaterBroadcastEndToEndTest)
@@ -142,9 +142,9 @@ BOOST_AUTO_TEST_CASE(ClGreaterBroadcastEndToEndTest)
const std::vector<uint8_t> expectedOutput({ 0, 1, 0, 0, 0, 1,
1, 1, 1, 1, 1, 1 });
- ArithmeticBroadcastEndToEnd<armnn::DataType::Float32, armnn::DataType::Boolean>(defaultBackends,
- LayerType::Greater,
- expectedOutput);
+ ComparisonBroadcastEndToEnd<armnn::DataType::Float32>(defaultBackends,
+ ComparisonOperation::Greater,
+ expectedOutput);
}
BOOST_AUTO_TEST_CASE(ClGreaterBroadcastEndToEndUint8Test)
@@ -152,9 +152,9 @@ BOOST_AUTO_TEST_CASE(ClGreaterBroadcastEndToEndUint8Test)
const std::vector<uint8_t> expectedOutput({ 0, 1, 0, 0, 0, 1,
1, 1, 1, 1, 1, 1 });
- ArithmeticBroadcastEndToEnd<armnn::DataType::QuantisedAsymm8, armnn::DataType::Boolean>(defaultBackends,
- LayerType::Greater,
- expectedOutput);
+ ComparisonBroadcastEndToEnd<armnn::DataType::QuantisedAsymm8>(defaultBackends,
+ ComparisonOperation::Greater,
+ expectedOutput);
}
// InstanceNormalization
diff --git a/src/backends/neon/NeonLayerSupport.cpp b/src/backends/neon/NeonLayerSupport.cpp
index 270cb6264f..cc96f63c1a 100644
--- a/src/backends/neon/NeonLayerSupport.cpp
+++ b/src/backends/neon/NeonLayerSupport.cpp
@@ -180,6 +180,24 @@ bool NeonLayerSupport::IsBatchNormalizationSupported(const TensorInfo& input,
descriptor);
}
+bool NeonLayerSupport::IsComparisonSupported(const TensorInfo& input0,
+ const TensorInfo& input1,
+ const TensorInfo& output,
+ const ComparisonDescriptor& descriptor,
+ Optional<std::string&> reasonIfUnsupported) const
+{
+ if (descriptor.m_Operation == ComparisonOperation::Greater)
+ {
+ FORWARD_WORKLOAD_VALIDATE_FUNC(NeonGreaterWorkloadValidate,
+ reasonIfUnsupported,
+ input0,
+ input1,
+ output);
+ }
+
+ return false;
+}
+
bool NeonLayerSupport::IsConcatSupported(const std::vector<const TensorInfo*> inputs,
const TensorInfo& output,
const ConcatDescriptor& descriptor,
@@ -354,11 +372,8 @@ bool NeonLayerSupport::IsGreaterSupported(const armnn::TensorInfo& input0,
const armnn::TensorInfo& output,
armnn::Optional<std::string&> reasonIfUnsupported) const
{
- FORWARD_WORKLOAD_VALIDATE_FUNC(NeonGreaterWorkloadValidate,
- reasonIfUnsupported,
- input0,
- input1,
- output);
+ ComparisonDescriptor descriptor(ComparisonOperation::Greater);
+ return IsComparisonSupported(input0, input1, output, descriptor, reasonIfUnsupported);
}
bool NeonLayerSupport::IsInputSupported(const TensorInfo& input,
diff --git a/src/backends/neon/NeonLayerSupport.hpp b/src/backends/neon/NeonLayerSupport.hpp
index d6a24ad43b..9c1ca2a729 100644
--- a/src/backends/neon/NeonLayerSupport.hpp
+++ b/src/backends/neon/NeonLayerSupport.hpp
@@ -40,6 +40,12 @@ public:
const BatchNormalizationDescriptor& descriptor,
Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
+ bool IsComparisonSupported(const TensorInfo& input0,
+ const TensorInfo& input1,
+ const TensorInfo& output,
+ const ComparisonDescriptor& descriptor,
+ Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
+
bool IsConcatSupported(const std::vector<const TensorInfo*> inputs,
const TensorInfo& output,
const ConcatDescriptor& descriptor,
@@ -98,6 +104,7 @@ public:
const FullyConnectedDescriptor& descriptor,
Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
+ ARMNN_DEPRECATED_MSG("Use IsComparisonSupported instead")
bool IsGreaterSupported(const TensorInfo& input0,
const TensorInfo& input1,
const TensorInfo& output,
diff --git a/src/backends/neon/NeonWorkloadFactory.cpp b/src/backends/neon/NeonWorkloadFactory.cpp
index 5bd8f293c5..dda1d7a132 100644
--- a/src/backends/neon/NeonWorkloadFactory.cpp
+++ b/src/backends/neon/NeonWorkloadFactory.cpp
@@ -131,6 +131,20 @@ std::unique_ptr<IWorkload> NeonWorkloadFactory::CreateBatchToSpaceNd(const Batch
return MakeWorkloadHelper<NullWorkload, NullWorkload>(descriptor, info);
}
+std::unique_ptr<IWorkload> NeonWorkloadFactory::CreateComparison(const ComparisonQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const
+{
+ if (descriptor.m_Parameters.m_Operation == ComparisonOperation::Greater)
+ {
+ GreaterQueueDescriptor greaterQueueDescriptor;
+ greaterQueueDescriptor.m_Inputs = descriptor.m_Inputs;
+ greaterQueueDescriptor.m_Outputs = descriptor.m_Outputs;
+
+ return MakeWorkloadHelper<NeonGreaterFloat32Workload, NeonGreaterUint8Workload>(greaterQueueDescriptor, info);
+ }
+ return MakeWorkloadHelper<NullWorkload, NullWorkload>(descriptor, info);
+}
+
std::unique_ptr<armnn::IWorkload> NeonWorkloadFactory::CreateConcat(const ConcatQueueDescriptor& descriptor,
const WorkloadInfo& info) const
{
@@ -203,7 +217,12 @@ std::unique_ptr<armnn::IWorkload> NeonWorkloadFactory::CreateDivision(
std::unique_ptr<IWorkload> NeonWorkloadFactory::CreateEqual(const EqualQueueDescriptor& descriptor,
const WorkloadInfo& info) const
{
- return MakeWorkloadHelper<NullWorkload, NullWorkload>(descriptor, info);
+ boost::ignore_unused(descriptor);
+
+ ComparisonQueueDescriptor comparisonDescriptor;
+ comparisonDescriptor.m_Parameters = ComparisonDescriptor(ComparisonOperation::Equal);
+
+ return CreateComparison(comparisonDescriptor, info);
}
std::unique_ptr<IWorkload> NeonWorkloadFactory::CreateFakeQuantization(
@@ -235,7 +254,12 @@ std::unique_ptr<IWorkload> NeonWorkloadFactory::CreateGather(const armnn::Gather
std::unique_ptr<IWorkload> NeonWorkloadFactory::CreateGreater(const GreaterQueueDescriptor& descriptor,
const WorkloadInfo& info) const
{
- return MakeWorkloadHelper<NeonGreaterFloat32Workload, NeonGreaterUint8Workload>(descriptor, info);
+ boost::ignore_unused(descriptor);
+
+ ComparisonQueueDescriptor comparisonDescriptor;
+ comparisonDescriptor.m_Parameters = ComparisonDescriptor(ComparisonOperation::Greater);
+
+ return CreateComparison(comparisonDescriptor, info);
}
std::unique_ptr<IWorkload> NeonWorkloadFactory::CreateInput(const InputQueueDescriptor& descriptor,
diff --git a/src/backends/neon/NeonWorkloadFactory.hpp b/src/backends/neon/NeonWorkloadFactory.hpp
index 9546164d36..5bee771528 100644
--- a/src/backends/neon/NeonWorkloadFactory.hpp
+++ b/src/backends/neon/NeonWorkloadFactory.hpp
@@ -57,6 +57,9 @@ public:
std::unique_ptr<IWorkload> CreateBatchToSpaceNd(const BatchToSpaceNdQueueDescriptor& descriptor,
const WorkloadInfo& Info) const override;
+ std::unique_ptr<IWorkload> CreateComparison(const ComparisonQueueDescriptor& descriptor,
+ const WorkloadInfo& Info) const override;
+
std::unique_ptr<IWorkload> CreateConcat(const ConcatQueueDescriptor& descriptor,
const WorkloadInfo& info) const override;
@@ -90,6 +93,7 @@ public:
std::unique_ptr<IWorkload> CreateDivision(const DivisionQueueDescriptor& descriptor,
const WorkloadInfo& info) const override;
+ ARMNN_DEPRECATED_MSG("Use CreateComparison instead")
std::unique_ptr<IWorkload> CreateEqual(const EqualQueueDescriptor& descriptor,
const WorkloadInfo& info) const override;
@@ -105,6 +109,7 @@ public:
std::unique_ptr<IWorkload> CreateGather(const GatherQueueDescriptor& descriptor,
const WorkloadInfo& info) const override;
+ ARMNN_DEPRECATED_MSG("Use CreateComparison instead")
std::unique_ptr<IWorkload> CreateGreater(const GreaterQueueDescriptor& descriptor,
const WorkloadInfo& info) const override;
diff --git a/src/backends/neon/test/NeonEndToEndTests.cpp b/src/backends/neon/test/NeonEndToEndTests.cpp
index 88f7ae7b5d..5146a598c7 100644
--- a/src/backends/neon/test/NeonEndToEndTests.cpp
+++ b/src/backends/neon/test/NeonEndToEndTests.cpp
@@ -6,7 +6,7 @@
#include <backendsCommon/test/EndToEndTestImpl.hpp>
#include <backendsCommon/test/AbsEndToEndTestImpl.hpp>
-#include <backendsCommon/test/ArithmeticTestImpl.hpp>
+#include <backendsCommon/test/ComparisonEndToEndTestImpl.hpp>
#include <backendsCommon/test/ConcatEndToEndTestImpl.hpp>
#include <backendsCommon/test/DepthToSpaceEndToEndTestImpl.hpp>
#include <backendsCommon/test/DequantizeEndToEndTestImpl.hpp>
@@ -80,9 +80,9 @@ BOOST_AUTO_TEST_CASE(NeonGreaterSimpleEndToEndTest)
const std::vector<uint8_t> expectedOutput({ 0, 0, 0, 0, 1, 1, 1, 1,
0, 0, 0, 0, 0, 0, 0, 0 });
- ArithmeticSimpleEndToEnd<armnn::DataType::Float32, armnn::DataType::Boolean>(defaultBackends,
- LayerType::Greater,
- expectedOutput);
+ ComparisonSimpleEndToEnd<armnn::DataType::Float32>(defaultBackends,
+ ComparisonOperation::Greater,
+ expectedOutput);
}
BOOST_AUTO_TEST_CASE(NeonGreaterSimpleEndToEndUint8Test)
@@ -90,9 +90,9 @@ BOOST_AUTO_TEST_CASE(NeonGreaterSimpleEndToEndUint8Test)
const std::vector<uint8_t> expectedOutput({ 0, 0, 0, 0, 1, 1, 1, 1,
0, 0, 0, 0, 0, 0, 0, 0 });
- ArithmeticSimpleEndToEnd<armnn::DataType::QuantisedAsymm8, armnn::DataType::Boolean>(defaultBackends,
- LayerType::Greater,
- expectedOutput);
+ ComparisonSimpleEndToEnd<armnn::DataType::QuantisedAsymm8>(defaultBackends,
+ ComparisonOperation::Greater,
+ expectedOutput);
}
BOOST_AUTO_TEST_CASE(NeonGreaterBroadcastEndToEndTest)
@@ -100,9 +100,9 @@ BOOST_AUTO_TEST_CASE(NeonGreaterBroadcastEndToEndTest)
const std::vector<uint8_t> expectedOutput({ 0, 1, 0, 0, 0, 1,
1, 1, 1, 1, 1, 1 });
- ArithmeticBroadcastEndToEnd<armnn::DataType::Float32, armnn::DataType::Boolean>(defaultBackends,
- LayerType::Greater,
- expectedOutput);
+ ComparisonBroadcastEndToEnd<armnn::DataType::Float32>(defaultBackends,
+ ComparisonOperation::Greater,
+ expectedOutput);
}
BOOST_AUTO_TEST_CASE(NeonGreaterBroadcastEndToEndUint8Test)
@@ -110,9 +110,9 @@ BOOST_AUTO_TEST_CASE(NeonGreaterBroadcastEndToEndUint8Test)
const std::vector<uint8_t> expectedOutput({ 0, 1, 0, 0, 0, 1,
1, 1, 1, 1, 1, 1 });
- ArithmeticBroadcastEndToEnd<armnn::DataType::QuantisedAsymm8, armnn::DataType::Boolean>(defaultBackends,
- LayerType::Greater,
- expectedOutput);
+ ComparisonBroadcastEndToEnd<armnn::DataType::QuantisedAsymm8>(defaultBackends,
+ ComparisonOperation::Greater,
+ expectedOutput);
}
BOOST_AUTO_TEST_CASE(NeonConcatEndToEndDim0Test)
diff --git a/src/backends/reference/RefLayerSupport.cpp b/src/backends/reference/RefLayerSupport.cpp
index 9342b29f47..c65886ba4d 100644
--- a/src/backends/reference/RefLayerSupport.cpp
+++ b/src/backends/reference/RefLayerSupport.cpp
@@ -308,6 +308,35 @@ bool RefLayerSupport::IsBatchToSpaceNdSupported(const TensorInfo& input,
return supported;
}
+bool RefLayerSupport::IsComparisonSupported(const TensorInfo& input0,
+ const TensorInfo& input1,
+ const TensorInfo& output,
+ const ComparisonDescriptor& descriptor,
+ Optional<std::string&> reasonIfUnsupported) const
+{
+ boost::ignore_unused(descriptor);
+
+ std::array<DataType, 4> supportedInputTypes =
+ {
+ DataType::Float32,
+ DataType::Float16,
+ DataType::QuantisedAsymm8,
+ DataType::QuantisedSymm16
+ };
+
+ bool supported = true;
+ supported &= CheckSupportRule(TypeAnyOf(input0, supportedInputTypes), reasonIfUnsupported,
+ "Reference comparison: input 0 is not a supported type");
+
+ supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
+                                  "Reference comparison: input 0 and input 1 types are mismatched");
+
+ supported &= CheckSupportRule(TypeIs(output, DataType::Boolean), reasonIfUnsupported,
+ "Reference comparison: output is not of type Boolean");
+
+ return supported;
+}
+
bool RefLayerSupport::IsConcatSupported(const std::vector<const TensorInfo*> inputs,
const TensorInfo& output,
const ConcatDescriptor& descriptor,
@@ -644,29 +673,11 @@ bool RefLayerSupport::IsEqualSupported(const TensorInfo& input0,
const TensorInfo& output,
Optional<std::string&> reasonIfUnsupported) const
{
- bool supported = true;
-
- std::array<DataType,4> supportedTypes =
- {
- DataType::Float32,
- DataType::Float16,
- DataType::QuantisedAsymm8,
- DataType::QuantisedSymm16
- };
-
- supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
- "Reference equal: input 0 is not a supported type.");
-
- supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
- "Reference equal: input 1 is not a supported type.");
-
- supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
- "Reference equal: input 0 and Input 1 types are mismatched");
-
- supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
- "Reference equal: shapes are not suitable for implicit broadcast.");
-
- return supported;
+ return IsComparisonSupported(input0,
+ input1,
+ output,
+ ComparisonDescriptor(ComparisonOperation::Equal),
+ reasonIfUnsupported);
}
bool RefLayerSupport::IsFakeQuantizationSupported(const TensorInfo& input,
@@ -802,29 +813,11 @@ bool RefLayerSupport::IsGreaterSupported(const TensorInfo& input0,
const TensorInfo& output,
Optional<std::string&> reasonIfUnsupported) const
{
- bool supported = true;
-
- std::array<DataType,4> supportedTypes =
- {
- DataType::Float32,
- DataType::Float16,
- DataType::QuantisedAsymm8,
- DataType::QuantisedSymm16
- };
-
- supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
- "Reference greater: input 0 is not a supported type.");
-
- supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
- "Reference greater: input 1 is not a supported type.");
-
- supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
- "Reference greater: input 0 and Input 1 types are mismatched");
-
- supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
- "Reference greater: shapes are not suitable for implicit broadcast.");
-
- return supported;
+ return IsComparisonSupported(input0,
+ input1,
+ output,
+ ComparisonDescriptor(ComparisonOperation::Greater),
+ reasonIfUnsupported);
}
bool RefLayerSupport::IsInputSupported(const TensorInfo& input,
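
On the reference backend the deprecated IsEqualSupported and IsGreaterSupported now delegate to IsComparisonSupported, whose rules are operation-independent: a supported input type, matching input types, and a Boolean output. A small sketch, assuming the internal RefLayerSupport.hpp header is reachable and with example shapes:

// Illustrative sketch only: the reference rules accept any ComparisonOperation,
// so a LessOrEqual query passes the same checks as Equal or Greater.
#include "RefLayerSupport.hpp"

#include <armnn/Descriptors.hpp>
#include <armnn/Tensor.hpp>
#include <armnn/Types.hpp>

bool RefSupportsLessOrEqualExample()
{
    using namespace armnn;

    RefLayerSupport layerSupport;

    const TensorInfo input({ 2, 2 }, DataType::QuantisedSymm16, 0.1f, 0);
    const TensorInfo output({ 2, 2 }, DataType::Boolean);

    return layerSupport.IsComparisonSupported(
        input, input, output, ComparisonDescriptor(ComparisonOperation::LessOrEqual));
}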
diff --git a/src/backends/reference/RefLayerSupport.hpp b/src/backends/reference/RefLayerSupport.hpp
index 5c71e8d337..04b355ee0a 100644
--- a/src/backends/reference/RefLayerSupport.hpp
+++ b/src/backends/reference/RefLayerSupport.hpp
@@ -45,6 +45,12 @@ public:
const BatchToSpaceNdDescriptor& descriptor,
Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
+ bool IsComparisonSupported(const TensorInfo& input0,
+ const TensorInfo& input1,
+ const TensorInfo& output,
+ const ComparisonDescriptor& descriptor,
+ Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
+
bool IsConcatSupported(const std::vector<const TensorInfo*> inputs,
const TensorInfo& output,
const ConcatDescriptor& descriptor,
@@ -106,6 +112,7 @@ public:
const TensorInfo& output,
Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
+ ARMNN_DEPRECATED_MSG("Use IsComparisonSupported instead")
bool IsEqualSupported(const TensorInfo& input0,
const TensorInfo& input1,
const TensorInfo& output,
@@ -131,6 +138,7 @@ public:
const TensorInfo& output,
Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
+ ARMNN_DEPRECATED_MSG("Use IsComparisonSupported instead")
bool IsGreaterSupported(const TensorInfo& input0,
const TensorInfo& input1,
const TensorInfo& output,
diff --git a/src/backends/reference/RefWorkloadFactory.cpp b/src/backends/reference/RefWorkloadFactory.cpp
index 1f6d1d7e8b..c2cb51abf3 100644
--- a/src/backends/reference/RefWorkloadFactory.cpp
+++ b/src/backends/reference/RefWorkloadFactory.cpp
@@ -131,6 +131,12 @@ std::unique_ptr<IWorkload> RefWorkloadFactory::CreateBatchToSpaceNd(const BatchT
return std::make_unique<RefBatchToSpaceNdWorkload>(descriptor, info);
}
+std::unique_ptr<IWorkload> RefWorkloadFactory::CreateComparison(const ComparisonQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const
+{
+ return std::make_unique<RefComparisonWorkload>(descriptor, info);
+}
+
std::unique_ptr<IWorkload> RefWorkloadFactory::CreateConcat(const ConcatQueueDescriptor& descriptor,
const WorkloadInfo& info) const
{
@@ -208,7 +214,10 @@ std::unique_ptr<IWorkload> RefWorkloadFactory::CreateDivision(const DivisionQueu
std::unique_ptr<IWorkload> RefWorkloadFactory::CreateEqual(const EqualQueueDescriptor& descriptor,
const WorkloadInfo& info) const
{
- return std::make_unique<RefEqualWorkload>(descriptor, info);
+ ComparisonQueueDescriptor comparisonDescriptor;
+ comparisonDescriptor.m_Parameters.m_Operation = ComparisonOperation::Equal;
+
+ return CreateComparison(comparisonDescriptor, info);
}
std::unique_ptr<IWorkload> RefWorkloadFactory::CreateFakeQuantization(
@@ -240,7 +249,10 @@ std::unique_ptr<IWorkload> RefWorkloadFactory::CreateGather(const GatherQueueDes
std::unique_ptr<IWorkload> RefWorkloadFactory::CreateGreater(const GreaterQueueDescriptor& descriptor,
const WorkloadInfo& info) const
{
- return std::make_unique<RefGreaterWorkload>(descriptor, info);
+ ComparisonQueueDescriptor comparisonDescriptor;
+ comparisonDescriptor.m_Parameters.m_Operation = ComparisonOperation::Greater;
+
+ return CreateComparison(comparisonDescriptor, info);
}
std::unique_ptr<IWorkload> RefWorkloadFactory::CreateInput(const InputQueueDescriptor& descriptor,
diff --git a/src/backends/reference/RefWorkloadFactory.hpp b/src/backends/reference/RefWorkloadFactory.hpp
index 41e9b28ea2..7b73d5b21f 100644
--- a/src/backends/reference/RefWorkloadFactory.hpp
+++ b/src/backends/reference/RefWorkloadFactory.hpp
@@ -78,6 +78,9 @@ public:
std::unique_ptr<IWorkload> CreateBatchToSpaceNd(const BatchToSpaceNdQueueDescriptor& descriptor,
const WorkloadInfo& info) const override;
+ std::unique_ptr<IWorkload> CreateComparison(const ComparisonQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const override;
+
std::unique_ptr<IWorkload> CreateConcat(const ConcatQueueDescriptor& descriptor,
const WorkloadInfo& info) const override;
@@ -111,6 +114,7 @@ public:
std::unique_ptr<IWorkload> CreateDivision(const DivisionQueueDescriptor& descriptor,
const WorkloadInfo& info) const override;
+ ARMNN_DEPRECATED_MSG("Use CreateComparison instead")
std::unique_ptr<IWorkload> CreateEqual(const EqualQueueDescriptor& descriptor,
const WorkloadInfo& info) const override;
@@ -126,6 +130,7 @@ public:
std::unique_ptr<IWorkload> CreateGather(const GatherQueueDescriptor& descriptor,
const WorkloadInfo& info) const override;
+ ARMNN_DEPRECATED_MSG("Use CreateComparison instead")
std::unique_ptr<IWorkload> CreateGreater(const GreaterQueueDescriptor& descriptor,
const WorkloadInfo& info) const override;
diff --git a/src/backends/reference/backend.mk b/src/backends/reference/backend.mk
index 49b07a41d2..7e97acdee2 100644
--- a/src/backends/reference/backend.mk
+++ b/src/backends/reference/backend.mk
@@ -47,6 +47,7 @@ BACKEND_SOURCES := \
workloads/RefArgMinMaxWorkload.cpp \
workloads/RefBatchNormalizationWorkload.cpp \
workloads/RefBatchToSpaceNdWorkload.cpp \
+ workloads/RefComparisonWorkload.cpp \
workloads/RefConcatWorkload.cpp \
workloads/RefConstantWorkload.cpp \
workloads/RefConvertFp16ToFp32Workload.cpp \
diff --git a/src/backends/reference/test/RefEndToEndTests.cpp b/src/backends/reference/test/RefEndToEndTests.cpp
index 370ef6599b..1968e4da7e 100644
--- a/src/backends/reference/test/RefEndToEndTests.cpp
+++ b/src/backends/reference/test/RefEndToEndTests.cpp
@@ -6,8 +6,8 @@
#include <backendsCommon/test/EndToEndTestImpl.hpp>
#include <backendsCommon/test/AbsEndToEndTestImpl.hpp>
-#include <backendsCommon/test/ArithmeticTestImpl.hpp>
#include <backendsCommon/test/BatchToSpaceNdEndToEndTestImpl.hpp>
+#include <backendsCommon/test/ComparisonEndToEndTestImpl.hpp>
#include <backendsCommon/test/ConcatEndToEndTestImpl.hpp>
#include <backendsCommon/test/DepthToSpaceEndToEndTestImpl.hpp>
#include <backendsCommon/test/DequantizeEndToEndTestImpl.hpp>
@@ -348,9 +348,9 @@ BOOST_AUTO_TEST_CASE(RefEqualSimpleEndToEndTest)
const std::vector<uint8_t> expectedOutput({ 1, 1, 1, 1, 0, 0, 0, 0,
0, 0, 0, 0, 1, 1, 1, 1 });
- ArithmeticSimpleEndToEnd<armnn::DataType::Float32, armnn::DataType::Boolean>(defaultBackends,
- LayerType::Equal,
- expectedOutput);
+ ComparisonSimpleEndToEnd<armnn::DataType::Float32>(defaultBackends,
+ ComparisonOperation::Equal,
+ expectedOutput);
}
BOOST_AUTO_TEST_CASE(RefGreaterSimpleEndToEndTest)
@@ -358,9 +358,9 @@ BOOST_AUTO_TEST_CASE(RefGreaterSimpleEndToEndTest)
const std::vector<uint8_t> expectedOutput({ 0, 0, 0, 0, 1, 1, 1, 1,
0, 0, 0, 0, 0, 0, 0, 0 });
- ArithmeticSimpleEndToEnd<armnn::DataType::Float32, armnn::DataType::Boolean>(defaultBackends,
- LayerType::Greater,
- expectedOutput);
+ ComparisonSimpleEndToEnd<armnn::DataType::Float32>(defaultBackends,
+ ComparisonOperation::Greater,
+ expectedOutput);
}
BOOST_AUTO_TEST_CASE(RefEqualSimpleEndToEndUint8Test)
@@ -368,9 +368,9 @@ BOOST_AUTO_TEST_CASE(RefEqualSimpleEndToEndUint8Test)
const std::vector<uint8_t> expectedOutput({ 1, 1, 1, 1, 0, 0, 0, 0,
0, 0, 0, 0, 1, 1, 1, 1 });
- ArithmeticSimpleEndToEnd<armnn::DataType::QuantisedAsymm8, armnn::DataType::Boolean>(defaultBackends,
- LayerType::Equal,
- expectedOutput);
+ ComparisonSimpleEndToEnd<armnn::DataType::QuantisedAsymm8>(defaultBackends,
+ ComparisonOperation::Equal,
+ expectedOutput);
}
BOOST_AUTO_TEST_CASE(RefGreaterSimpleEndToEndUint8Test)
@@ -378,9 +378,9 @@ BOOST_AUTO_TEST_CASE(RefGreaterSimpleEndToEndUint8Test)
const std::vector<uint8_t> expectedOutput({ 0, 0, 0, 0, 1, 1, 1, 1,
0, 0, 0, 0, 0, 0, 0, 0 });
- ArithmeticSimpleEndToEnd<armnn::DataType::QuantisedAsymm8, armnn::DataType::Boolean>(defaultBackends,
- LayerType::Greater,
- expectedOutput);
+ ComparisonSimpleEndToEnd<armnn::DataType::QuantisedAsymm8>(defaultBackends,
+ ComparisonOperation::Greater,
+ expectedOutput);
}
BOOST_AUTO_TEST_CASE(RefEqualBroadcastEndToEndTest)
@@ -388,9 +388,9 @@ BOOST_AUTO_TEST_CASE(RefEqualBroadcastEndToEndTest)
const std::vector<uint8_t> expectedOutput({ 1, 0, 1, 1, 0, 0,
0, 0, 0, 0, 0, 0 });
- ArithmeticBroadcastEndToEnd<armnn::DataType::Float32, armnn::DataType::Boolean>(defaultBackends,
- LayerType::Equal,
- expectedOutput);
+ ComparisonBroadcastEndToEnd<armnn::DataType::Float32>(defaultBackends,
+ ComparisonOperation::Equal,
+ expectedOutput);
}
BOOST_AUTO_TEST_CASE(RefGreaterBroadcastEndToEndTest)
@@ -398,9 +398,9 @@ BOOST_AUTO_TEST_CASE(RefGreaterBroadcastEndToEndTest)
const std::vector<uint8_t> expectedOutput({ 0, 1, 0, 0, 0, 1,
1, 1, 1, 1, 1, 1 });
- ArithmeticBroadcastEndToEnd<armnn::DataType::Float32, armnn::DataType::Boolean>(defaultBackends,
- LayerType::Greater,
- expectedOutput);
+ ComparisonBroadcastEndToEnd<armnn::DataType::Float32>(defaultBackends,
+ ComparisonOperation::Greater,
+ expectedOutput);
}
BOOST_AUTO_TEST_CASE(RefEqualBroadcastEndToEndUint8Test)
@@ -408,9 +408,9 @@ BOOST_AUTO_TEST_CASE(RefEqualBroadcastEndToEndUint8Test)
const std::vector<uint8_t > expectedOutput({ 1, 0, 1, 1, 0, 0,
0, 0, 0, 0, 0, 0 });
- ArithmeticBroadcastEndToEnd<armnn::DataType::QuantisedAsymm8, armnn::DataType::Boolean>(defaultBackends,
- LayerType::Equal,
- expectedOutput);
+ ComparisonBroadcastEndToEnd<armnn::DataType::QuantisedAsymm8>(defaultBackends,
+ ComparisonOperation::Equal,
+ expectedOutput);
}
BOOST_AUTO_TEST_CASE(RefGreaterBroadcastEndToEndUint8Test)
@@ -418,9 +418,9 @@ BOOST_AUTO_TEST_CASE(RefGreaterBroadcastEndToEndUint8Test)
const std::vector<uint8_t> expectedOutput({ 0, 1, 0, 0, 0, 1,
1, 1, 1, 1, 1, 1 });
- ArithmeticBroadcastEndToEnd<armnn::DataType::QuantisedAsymm8, armnn::DataType::Boolean>(defaultBackends,
- LayerType::Greater,
- expectedOutput);
+ ComparisonBroadcastEndToEnd<armnn::DataType::QuantisedAsymm8>(defaultBackends,
+ ComparisonOperation::Greater,
+ expectedOutput);
}
BOOST_AUTO_TEST_CASE(RefBatchToSpaceNdEndToEndFloat32NHWCTest)
diff --git a/src/backends/reference/workloads/CMakeLists.txt b/src/backends/reference/workloads/CMakeLists.txt
index b8eb95c729..7844518620 100644
--- a/src/backends/reference/workloads/CMakeLists.txt
+++ b/src/backends/reference/workloads/CMakeLists.txt
@@ -63,6 +63,8 @@ list(APPEND armnnRefBackendWorkloads_sources
RefBatchNormalizationWorkload.hpp
RefBatchToSpaceNdWorkload.cpp
RefBatchToSpaceNdWorkload.hpp
+ RefComparisonWorkload.cpp
+ RefComparisonWorkload.hpp
RefConcatWorkload.cpp
RefConcatWorkload.hpp
RefConstantWorkload.cpp
diff --git a/src/backends/reference/workloads/ElementwiseFunction.cpp b/src/backends/reference/workloads/ElementwiseFunction.cpp
index 7a5c071f70..888037f9a6 100644
--- a/src/backends/reference/workloads/ElementwiseFunction.cpp
+++ b/src/backends/reference/workloads/ElementwiseFunction.cpp
@@ -32,6 +32,11 @@ template struct armnn::ElementwiseFunction<std::multiplies<float>>;
template struct armnn::ElementwiseFunction<std::divides<float>>;
template struct armnn::ElementwiseFunction<armnn::maximum<float>>;
template struct armnn::ElementwiseFunction<armnn::minimum<float>>;
+
+// Comparison
template struct armnn::ElementwiseFunction<std::equal_to<float>>;
template struct armnn::ElementwiseFunction<std::greater<float>>;
-
+template struct armnn::ElementwiseFunction<std::greater_equal<float>>;
+template struct armnn::ElementwiseFunction<std::less<float>>;
+template struct armnn::ElementwiseFunction<std::less_equal<float>>;
+template struct armnn::ElementwiseFunction<std::not_equal_to<float>>;
diff --git a/src/backends/reference/workloads/RefComparisonWorkload.cpp b/src/backends/reference/workloads/RefComparisonWorkload.cpp
new file mode 100644
index 0000000000..60446226be
--- /dev/null
+++ b/src/backends/reference/workloads/RefComparisonWorkload.cpp
@@ -0,0 +1,102 @@
+//
+// Copyright © 2019 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include "RefComparisonWorkload.hpp"
+
+#include "Decoders.hpp"
+#include "ElementwiseFunction.hpp"
+#include "Encoders.hpp"
+#include "RefWorkloadUtils.hpp"
+
+#include <Profiling.hpp>
+
+#include <armnn/TypesUtils.hpp>
+
+#include <functional>
+
+namespace armnn
+{
+
+RefComparisonWorkload::RefComparisonWorkload(const ComparisonQueueDescriptor& desc,
+ const WorkloadInfo& info)
+ : BaseWorkload<ComparisonQueueDescriptor>(desc, info)
+{}
+
+void RefComparisonWorkload::PostAllocationConfigure()
+{
+ const TensorInfo& inputInfo0 = GetTensorInfo(m_Data.m_Inputs[0]);
+ const TensorInfo& inputInfo1 = GetTensorInfo(m_Data.m_Inputs[1]);
+ const TensorInfo& outputInfo = GetTensorInfo(m_Data.m_Outputs[0]);
+
+ m_Input0 = MakeDecoder<InType>(inputInfo0);
+ m_Input1 = MakeDecoder<InType>(inputInfo1);
+
+ m_Output = MakeEncoder<OutType>(outputInfo);
+}
+
+void RefComparisonWorkload::Execute() const
+{
+ ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefComparisonWorkload_Execute");
+
+ const TensorInfo& inputInfo0 = GetTensorInfo(m_Data.m_Inputs[0]);
+ const TensorInfo& inputInfo1 = GetTensorInfo(m_Data.m_Inputs[1]);
+ const TensorInfo& outputInfo = GetTensorInfo(m_Data.m_Outputs[0]);
+
+ const TensorShape& inShape0 = inputInfo0.GetShape();
+ const TensorShape& inShape1 = inputInfo1.GetShape();
+ const TensorShape& outShape = outputInfo.GetShape();
+
+ m_Input0->Reset(m_Data.m_Inputs[0]->Map());
+ m_Input1->Reset(m_Data.m_Inputs[1]->Map());
+ m_Output->Reset(m_Data.m_Outputs[0]->Map());
+
+ using EqualFunction = ElementwiseFunction<std::equal_to<InType>>;
+ using GreaterFunction = ElementwiseFunction<std::greater<InType>>;
+ using GreaterOrEqualFunction = ElementwiseFunction<std::greater_equal<InType>>;
+ using LessFunction = ElementwiseFunction<std::less<InType>>;
+ using LessOrEqualFunction = ElementwiseFunction<std::less_equal<InType>>;
+ using NotEqualFunction = ElementwiseFunction<std::not_equal_to<InType>>;
+
+ switch (m_Data.m_Parameters.m_Operation)
+ {
+ case ComparisonOperation::Equal:
+ {
+ EqualFunction(inShape0, inShape1, outShape, *m_Input0, *m_Input1, *m_Output);
+ break;
+ }
+ case ComparisonOperation::Greater:
+ {
+ GreaterFunction(inShape0, inShape1, outShape, *m_Input0, *m_Input1, *m_Output);
+ break;
+ }
+ case ComparisonOperation::GreaterOrEqual:
+ {
+ GreaterOrEqualFunction(inShape0, inShape1, outShape, *m_Input0, *m_Input1, *m_Output);
+ break;
+ }
+ case ComparisonOperation::Less:
+ {
+ LessFunction(inShape0, inShape1, outShape, *m_Input0, *m_Input1, *m_Output);
+ break;
+ }
+ case ComparisonOperation::LessOrEqual:
+ {
+ LessOrEqualFunction(inShape0, inShape1, outShape, *m_Input0, *m_Input1, *m_Output);
+ break;
+ }
+ case ComparisonOperation::NotEqual:
+ {
+ NotEqualFunction(inShape0, inShape1, outShape, *m_Input0, *m_Input1, *m_Output);
+ break;
+ }
+ default:
+ {
+ throw InvalidArgumentException(std::string("Unsupported comparison operation ") +
+ GetComparisonOperationAsCString(m_Data.m_Parameters.m_Operation), CHECK_LOCATION());
+ }
+ }
+}
+
+} // namespace armnn
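
The workload maps each ComparisonOperation onto a standard-library comparator and lets ElementwiseFunction handle broadcasting between the decoded inputs and the Boolean encoder. The sketch below shows that core idea in isolation, on plain vectors without broadcasting or armnn types; it is purely illustrative.

// Illustrative sketch only: elementwise comparison with a std:: comparator,
// producing a Boolean output encoded as 0/1 bytes.
#include <cstdint>
#include <functional>
#include <vector>

std::vector<uint8_t> CompareGreater(const std::vector<float>& a, const std::vector<float>& b)
{
    std::greater<float> op;

    std::vector<uint8_t> result(a.size());
    for (std::size_t i = 0; i < a.size(); ++i)
    {
        result[i] = op(a[i], b[i]) ? 1 : 0;
    }
    return result;
}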
diff --git a/src/backends/reference/workloads/RefComparisonWorkload.hpp b/src/backends/reference/workloads/RefComparisonWorkload.hpp
new file mode 100644
index 0000000000..a19e4a0540
--- /dev/null
+++ b/src/backends/reference/workloads/RefComparisonWorkload.hpp
@@ -0,0 +1,34 @@
+//
+// Copyright © 2019 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+#include "BaseIterator.hpp"
+
+#include <backendsCommon/Workload.hpp>
+#include <backendsCommon/WorkloadData.hpp>
+
+namespace armnn
+{
+
+class RefComparisonWorkload : public BaseWorkload<ComparisonQueueDescriptor>
+{
+public:
+ using BaseWorkload<ComparisonQueueDescriptor>::m_Data;
+
+ RefComparisonWorkload(const ComparisonQueueDescriptor& descriptor, const WorkloadInfo& info);
+ void PostAllocationConfigure() override;
+ void Execute() const override;
+
+private:
+ using InType = float;
+ using OutType = bool;
+
+ std::unique_ptr<Decoder<InType>> m_Input0;
+ std::unique_ptr<Decoder<InType>> m_Input1;
+ std::unique_ptr<Encoder<OutType>> m_Output;
+};
+
+} // namespace armnn
diff --git a/src/backends/reference/workloads/RefElementwiseWorkload.cpp b/src/backends/reference/workloads/RefElementwiseWorkload.cpp
index 6431348bc2..7e02f032ef 100644
--- a/src/backends/reference/workloads/RefElementwiseWorkload.cpp
+++ b/src/backends/reference/workloads/RefElementwiseWorkload.cpp
@@ -86,11 +86,3 @@ template class armnn::RefElementwiseWorkload<armnn::maximum<float>,
template class armnn::RefElementwiseWorkload<armnn::minimum<float>,
armnn::MinimumQueueDescriptor,
armnn::StringMapping::RefMinimumWorkload_Execute>;
-
-template class armnn::RefElementwiseWorkload<std::equal_to<float>,
- armnn::EqualQueueDescriptor,
- armnn::StringMapping::RefEqualWorkload_Execute>;
-
-template class armnn::RefElementwiseWorkload<std::greater<float>,
- armnn::GreaterQueueDescriptor,
- armnn::StringMapping::RefGreaterWorkload_Execute>;
diff --git a/src/backends/reference/workloads/RefElementwiseWorkload.hpp b/src/backends/reference/workloads/RefElementwiseWorkload.hpp
index 651942e9e5..ee0d80b172 100644
--- a/src/backends/reference/workloads/RefElementwiseWorkload.hpp
+++ b/src/backends/reference/workloads/RefElementwiseWorkload.hpp
@@ -65,13 +65,4 @@ using RefMinimumWorkload =
MinimumQueueDescriptor,
StringMapping::RefMinimumWorkload_Execute>;
-using RefEqualWorkload =
- RefElementwiseWorkload<std::equal_to<float>,
- armnn::EqualQueueDescriptor,
- armnn::StringMapping::RefEqualWorkload_Execute>;
-
-using RefGreaterWorkload =
- RefElementwiseWorkload<std::greater<float>,
- armnn::GreaterQueueDescriptor,
- armnn::StringMapping::RefGreaterWorkload_Execute>;
} // armnn
diff --git a/src/backends/reference/workloads/RefWorkloads.hpp b/src/backends/reference/workloads/RefWorkloads.hpp
index 79d1935823..1f9ad4a19a 100644
--- a/src/backends/reference/workloads/RefWorkloads.hpp
+++ b/src/backends/reference/workloads/RefWorkloads.hpp
@@ -20,6 +20,7 @@
#include "RefArgMinMaxWorkload.hpp"
#include "RefBatchNormalizationWorkload.hpp"
#include "RefBatchToSpaceNdWorkload.hpp"
+#include "RefComparisonWorkload.hpp"
#include "RefConvolution2dWorkload.hpp"
#include "RefConstantWorkload.hpp"
#include "RefConcatWorkload.hpp"
diff --git a/src/backends/reference/workloads/StringMapping.hpp b/src/backends/reference/workloads/StringMapping.hpp
index 073a5a6833..1654b78088 100644
--- a/src/backends/reference/workloads/StringMapping.hpp
+++ b/src/backends/reference/workloads/StringMapping.hpp
@@ -18,9 +18,7 @@ struct StringMapping
public:
enum Id {
RefAdditionWorkload_Execute,
- RefEqualWorkload_Execute,
RefDivisionWorkload_Execute,
- RefGreaterWorkload_Execute,
RefMaximumWorkload_Execute,
RefMinimumWorkload_Execute,
RefMultiplicationWorkload_Execute,
@@ -40,8 +38,6 @@ private:
{
m_Strings[RefAdditionWorkload_Execute] = "RefAdditionWorkload_Execute";
m_Strings[RefDivisionWorkload_Execute] = "RefDivisionWorkload_Execute";
- m_Strings[RefEqualWorkload_Execute] = "RefEqualWorkload_Execute";
- m_Strings[RefGreaterWorkload_Execute] = "RefGreaterWorkload_Execute";
m_Strings[RefMaximumWorkload_Execute] = "RefMaximumWorkload_Execute";
m_Strings[RefMinimumWorkload_Execute] = "RefMinimumWorkload_Execute";
m_Strings[RefMultiplicationWorkload_Execute] = "RefMultiplicationWorkload_Execute";