aboutsummaryrefslogtreecommitdiff
path: root/src/backends/reference
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/reference')
-rw-r--r--src/backends/reference/RefLayerSupport.cpp85
-rw-r--r--src/backends/reference/RefLayerSupport.hpp8
-rw-r--r--src/backends/reference/RefWorkloadFactory.cpp16
-rw-r--r--src/backends/reference/RefWorkloadFactory.hpp5
-rw-r--r--src/backends/reference/backend.mk1
-rw-r--r--src/backends/reference/test/RefEndToEndTests.cpp50
-rw-r--r--src/backends/reference/workloads/CMakeLists.txt2
-rw-r--r--src/backends/reference/workloads/ElementwiseFunction.cpp7
-rw-r--r--src/backends/reference/workloads/RefComparisonWorkload.cpp102
-rw-r--r--src/backends/reference/workloads/RefComparisonWorkload.hpp34
-rw-r--r--src/backends/reference/workloads/RefElementwiseWorkload.cpp8
-rw-r--r--src/backends/reference/workloads/RefElementwiseWorkload.hpp9
-rw-r--r--src/backends/reference/workloads/RefWorkloads.hpp1
-rw-r--r--src/backends/reference/workloads/StringMapping.hpp4
14 files changed, 237 insertions, 95 deletions
diff --git a/src/backends/reference/RefLayerSupport.cpp b/src/backends/reference/RefLayerSupport.cpp
index 9342b29f47..c65886ba4d 100644
--- a/src/backends/reference/RefLayerSupport.cpp
+++ b/src/backends/reference/RefLayerSupport.cpp
@@ -308,6 +308,35 @@ bool RefLayerSupport::IsBatchToSpaceNdSupported(const TensorInfo& input,
return supported;
}
+bool RefLayerSupport::IsComparisonSupported(const TensorInfo& input0,
+ const TensorInfo& input1,
+ const TensorInfo& output,
+ const ComparisonDescriptor& descriptor,
+ Optional<std::string&> reasonIfUnsupported) const
+{
+ boost::ignore_unused(descriptor);
+
+ std::array<DataType, 4> supportedInputTypes =
+ {
+ DataType::Float32,
+ DataType::Float16,
+ DataType::QuantisedAsymm8,
+ DataType::QuantisedSymm16
+ };
+
+ bool supported = true;
+ supported &= CheckSupportRule(TypeAnyOf(input0, supportedInputTypes), reasonIfUnsupported,
+ "Reference comparison: input 0 is not a supported type");
+
+ supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
+ "Reference comparison: input 0 and Input 1 types are mismatched");
+
+ supported &= CheckSupportRule(TypeIs(output, DataType::Boolean), reasonIfUnsupported,
+ "Reference comparison: output is not of type Boolean");
+
+ return supported;
+}
+
bool RefLayerSupport::IsConcatSupported(const std::vector<const TensorInfo*> inputs,
const TensorInfo& output,
const ConcatDescriptor& descriptor,
@@ -644,29 +673,11 @@ bool RefLayerSupport::IsEqualSupported(const TensorInfo& input0,
const TensorInfo& output,
Optional<std::string&> reasonIfUnsupported) const
{
- bool supported = true;
-
- std::array<DataType,4> supportedTypes =
- {
- DataType::Float32,
- DataType::Float16,
- DataType::QuantisedAsymm8,
- DataType::QuantisedSymm16
- };
-
- supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
- "Reference equal: input 0 is not a supported type.");
-
- supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
- "Reference equal: input 1 is not a supported type.");
-
- supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
- "Reference equal: input 0 and Input 1 types are mismatched");
-
- supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
- "Reference equal: shapes are not suitable for implicit broadcast.");
-
- return supported;
+ return IsComparisonSupported(input0,
+ input1,
+ output,
+ ComparisonDescriptor(ComparisonOperation::Equal),
+ reasonIfUnsupported);
}
bool RefLayerSupport::IsFakeQuantizationSupported(const TensorInfo& input,
@@ -802,29 +813,11 @@ bool RefLayerSupport::IsGreaterSupported(const TensorInfo& input0,
const TensorInfo& output,
Optional<std::string&> reasonIfUnsupported) const
{
- bool supported = true;
-
- std::array<DataType,4> supportedTypes =
- {
- DataType::Float32,
- DataType::Float16,
- DataType::QuantisedAsymm8,
- DataType::QuantisedSymm16
- };
-
- supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported,
- "Reference greater: input 0 is not a supported type.");
-
- supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported,
- "Reference greater: input 1 is not a supported type.");
-
- supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported,
- "Reference greater: input 0 and Input 1 types are mismatched");
-
- supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported,
- "Reference greater: shapes are not suitable for implicit broadcast.");
-
- return supported;
+ return IsComparisonSupported(input0,
+ input1,
+ output,
+ ComparisonDescriptor(ComparisonOperation::Greater),
+ reasonIfUnsupported);
}
bool RefLayerSupport::IsInputSupported(const TensorInfo& input,
diff --git a/src/backends/reference/RefLayerSupport.hpp b/src/backends/reference/RefLayerSupport.hpp
index 5c71e8d337..04b355ee0a 100644
--- a/src/backends/reference/RefLayerSupport.hpp
+++ b/src/backends/reference/RefLayerSupport.hpp
@@ -45,6 +45,12 @@ public:
const BatchToSpaceNdDescriptor& descriptor,
Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
+ bool IsComparisonSupported(const TensorInfo& input0,
+ const TensorInfo& input1,
+ const TensorInfo& output,
+ const ComparisonDescriptor& descriptor,
+ Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
+
bool IsConcatSupported(const std::vector<const TensorInfo*> inputs,
const TensorInfo& output,
const ConcatDescriptor& descriptor,
@@ -106,6 +112,7 @@ public:
const TensorInfo& output,
Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
+ ARMNN_DEPRECATED_MSG("Use IsComparisonSupported instead")
bool IsEqualSupported(const TensorInfo& input0,
const TensorInfo& input1,
const TensorInfo& output,
@@ -131,6 +138,7 @@ public:
const TensorInfo& output,
Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
+ ARMNN_DEPRECATED_MSG("Use IsComparisonSupported instead")
bool IsGreaterSupported(const TensorInfo& input0,
const TensorInfo& input1,
const TensorInfo& output,
diff --git a/src/backends/reference/RefWorkloadFactory.cpp b/src/backends/reference/RefWorkloadFactory.cpp
index 1f6d1d7e8b..c2cb51abf3 100644
--- a/src/backends/reference/RefWorkloadFactory.cpp
+++ b/src/backends/reference/RefWorkloadFactory.cpp
@@ -131,6 +131,12 @@ std::unique_ptr<IWorkload> RefWorkloadFactory::CreateBatchToSpaceNd(const BatchT
return std::make_unique<RefBatchToSpaceNdWorkload>(descriptor, info);
}
+std::unique_ptr<IWorkload> RefWorkloadFactory::CreateComparison(const ComparisonQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const
+{
+ return std::make_unique<RefComparisonWorkload>(descriptor, info);
+}
+
std::unique_ptr<IWorkload> RefWorkloadFactory::CreateConcat(const ConcatQueueDescriptor& descriptor,
const WorkloadInfo& info) const
{
@@ -208,7 +214,10 @@ std::unique_ptr<IWorkload> RefWorkloadFactory::CreateDivision(const DivisionQueu
std::unique_ptr<IWorkload> RefWorkloadFactory::CreateEqual(const EqualQueueDescriptor& descriptor,
const WorkloadInfo& info) const
{
- return std::make_unique<RefEqualWorkload>(descriptor, info);
+ ComparisonQueueDescriptor comparisonDescriptor;
+ comparisonDescriptor.m_Parameters.m_Operation = ComparisonOperation::Equal;
+
+ return CreateComparison(comparisonDescriptor, info);
}
std::unique_ptr<IWorkload> RefWorkloadFactory::CreateFakeQuantization(
@@ -240,7 +249,10 @@ std::unique_ptr<IWorkload> RefWorkloadFactory::CreateGather(const GatherQueueDes
std::unique_ptr<IWorkload> RefWorkloadFactory::CreateGreater(const GreaterQueueDescriptor& descriptor,
const WorkloadInfo& info) const
{
- return std::make_unique<RefGreaterWorkload>(descriptor, info);
+ ComparisonQueueDescriptor comparisonDescriptor;
+ comparisonDescriptor.m_Parameters.m_Operation = ComparisonOperation::Greater;
+
+ return CreateComparison(comparisonDescriptor, info);
}
std::unique_ptr<IWorkload> RefWorkloadFactory::CreateInput(const InputQueueDescriptor& descriptor,
diff --git a/src/backends/reference/RefWorkloadFactory.hpp b/src/backends/reference/RefWorkloadFactory.hpp
index 41e9b28ea2..7b73d5b21f 100644
--- a/src/backends/reference/RefWorkloadFactory.hpp
+++ b/src/backends/reference/RefWorkloadFactory.hpp
@@ -78,6 +78,9 @@ public:
std::unique_ptr<IWorkload> CreateBatchToSpaceNd(const BatchToSpaceNdQueueDescriptor& descriptor,
const WorkloadInfo& info) const override;
+ std::unique_ptr<IWorkload> CreateComparison(const ComparisonQueueDescriptor& descriptor,
+ const WorkloadInfo& info) const override;
+
std::unique_ptr<IWorkload> CreateConcat(const ConcatQueueDescriptor& descriptor,
const WorkloadInfo& info) const override;
@@ -111,6 +114,7 @@ public:
std::unique_ptr<IWorkload> CreateDivision(const DivisionQueueDescriptor& descriptor,
const WorkloadInfo& info) const override;
+ ARMNN_DEPRECATED_MSG("Use CreateComparison instead")
std::unique_ptr<IWorkload> CreateEqual(const EqualQueueDescriptor& descriptor,
const WorkloadInfo& info) const override;
@@ -126,6 +130,7 @@ public:
std::unique_ptr<IWorkload> CreateGather(const GatherQueueDescriptor& descriptor,
const WorkloadInfo& info) const override;
+ ARMNN_DEPRECATED_MSG("Use CreateComparison instead")
std::unique_ptr<IWorkload> CreateGreater(const GreaterQueueDescriptor& descriptor,
const WorkloadInfo& info) const override;
diff --git a/src/backends/reference/backend.mk b/src/backends/reference/backend.mk
index 49b07a41d2..7e97acdee2 100644
--- a/src/backends/reference/backend.mk
+++ b/src/backends/reference/backend.mk
@@ -47,6 +47,7 @@ BACKEND_SOURCES := \
workloads/RefArgMinMaxWorkload.cpp \
workloads/RefBatchNormalizationWorkload.cpp \
workloads/RefBatchToSpaceNdWorkload.cpp \
+ workloads/RefComparisonWorkload.cpp \
workloads/RefConcatWorkload.cpp \
workloads/RefConstantWorkload.cpp \
workloads/RefConvertFp16ToFp32Workload.cpp \
diff --git a/src/backends/reference/test/RefEndToEndTests.cpp b/src/backends/reference/test/RefEndToEndTests.cpp
index 370ef6599b..1968e4da7e 100644
--- a/src/backends/reference/test/RefEndToEndTests.cpp
+++ b/src/backends/reference/test/RefEndToEndTests.cpp
@@ -6,8 +6,8 @@
#include <backendsCommon/test/EndToEndTestImpl.hpp>
#include <backendsCommon/test/AbsEndToEndTestImpl.hpp>
-#include <backendsCommon/test/ArithmeticTestImpl.hpp>
#include <backendsCommon/test/BatchToSpaceNdEndToEndTestImpl.hpp>
+#include <backendsCommon/test/ComparisonEndToEndTestImpl.hpp>
#include <backendsCommon/test/ConcatEndToEndTestImpl.hpp>
#include <backendsCommon/test/DepthToSpaceEndToEndTestImpl.hpp>
#include <backendsCommon/test/DequantizeEndToEndTestImpl.hpp>
@@ -348,9 +348,9 @@ BOOST_AUTO_TEST_CASE(RefEqualSimpleEndToEndTest)
const std::vector<uint8_t> expectedOutput({ 1, 1, 1, 1, 0, 0, 0, 0,
0, 0, 0, 0, 1, 1, 1, 1 });
- ArithmeticSimpleEndToEnd<armnn::DataType::Float32, armnn::DataType::Boolean>(defaultBackends,
- LayerType::Equal,
- expectedOutput);
+ ComparisonSimpleEndToEnd<armnn::DataType::Float32>(defaultBackends,
+ ComparisonOperation::Equal,
+ expectedOutput);
}
BOOST_AUTO_TEST_CASE(RefGreaterSimpleEndToEndTest)
@@ -358,9 +358,9 @@ BOOST_AUTO_TEST_CASE(RefGreaterSimpleEndToEndTest)
const std::vector<uint8_t> expectedOutput({ 0, 0, 0, 0, 1, 1, 1, 1,
0, 0, 0, 0, 0, 0, 0, 0 });
- ArithmeticSimpleEndToEnd<armnn::DataType::Float32, armnn::DataType::Boolean>(defaultBackends,
- LayerType::Greater,
- expectedOutput);
+ ComparisonSimpleEndToEnd<armnn::DataType::Float32>(defaultBackends,
+ ComparisonOperation::Greater,
+ expectedOutput);
}
BOOST_AUTO_TEST_CASE(RefEqualSimpleEndToEndUint8Test)
@@ -368,9 +368,9 @@ BOOST_AUTO_TEST_CASE(RefEqualSimpleEndToEndUint8Test)
const std::vector<uint8_t> expectedOutput({ 1, 1, 1, 1, 0, 0, 0, 0,
0, 0, 0, 0, 1, 1, 1, 1 });
- ArithmeticSimpleEndToEnd<armnn::DataType::QuantisedAsymm8, armnn::DataType::Boolean>(defaultBackends,
- LayerType::Equal,
- expectedOutput);
+ ComparisonSimpleEndToEnd<armnn::DataType::QuantisedAsymm8>(defaultBackends,
+ ComparisonOperation::Equal,
+ expectedOutput);
}
BOOST_AUTO_TEST_CASE(RefGreaterSimpleEndToEndUint8Test)
@@ -378,9 +378,9 @@ BOOST_AUTO_TEST_CASE(RefGreaterSimpleEndToEndUint8Test)
const std::vector<uint8_t> expectedOutput({ 0, 0, 0, 0, 1, 1, 1, 1,
0, 0, 0, 0, 0, 0, 0, 0 });
- ArithmeticSimpleEndToEnd<armnn::DataType::QuantisedAsymm8, armnn::DataType::Boolean>(defaultBackends,
- LayerType::Greater,
- expectedOutput);
+ ComparisonSimpleEndToEnd<armnn::DataType::QuantisedAsymm8>(defaultBackends,
+ ComparisonOperation::Greater,
+ expectedOutput);
}
BOOST_AUTO_TEST_CASE(RefEqualBroadcastEndToEndTest)
@@ -388,9 +388,9 @@ BOOST_AUTO_TEST_CASE(RefEqualBroadcastEndToEndTest)
const std::vector<uint8_t> expectedOutput({ 1, 0, 1, 1, 0, 0,
0, 0, 0, 0, 0, 0 });
- ArithmeticBroadcastEndToEnd<armnn::DataType::Float32, armnn::DataType::Boolean>(defaultBackends,
- LayerType::Equal,
- expectedOutput);
+ ComparisonBroadcastEndToEnd<armnn::DataType::Float32>(defaultBackends,
+ ComparisonOperation::Equal,
+ expectedOutput);
}
BOOST_AUTO_TEST_CASE(RefGreaterBroadcastEndToEndTest)
@@ -398,9 +398,9 @@ BOOST_AUTO_TEST_CASE(RefGreaterBroadcastEndToEndTest)
const std::vector<uint8_t> expectedOutput({ 0, 1, 0, 0, 0, 1,
1, 1, 1, 1, 1, 1 });
- ArithmeticBroadcastEndToEnd<armnn::DataType::Float32, armnn::DataType::Boolean>(defaultBackends,
- LayerType::Greater,
- expectedOutput);
+ ComparisonBroadcastEndToEnd<armnn::DataType::Float32>(defaultBackends,
+ ComparisonOperation::Greater,
+ expectedOutput);
}
BOOST_AUTO_TEST_CASE(RefEqualBroadcastEndToEndUint8Test)
@@ -408,9 +408,9 @@ BOOST_AUTO_TEST_CASE(RefEqualBroadcastEndToEndUint8Test)
const std::vector<uint8_t > expectedOutput({ 1, 0, 1, 1, 0, 0,
0, 0, 0, 0, 0, 0 });
- ArithmeticBroadcastEndToEnd<armnn::DataType::QuantisedAsymm8, armnn::DataType::Boolean>(defaultBackends,
- LayerType::Equal,
- expectedOutput);
+ ComparisonBroadcastEndToEnd<armnn::DataType::QuantisedAsymm8>(defaultBackends,
+ ComparisonOperation::Equal,
+ expectedOutput);
}
BOOST_AUTO_TEST_CASE(RefGreaterBroadcastEndToEndUint8Test)
@@ -418,9 +418,9 @@ BOOST_AUTO_TEST_CASE(RefGreaterBroadcastEndToEndUint8Test)
const std::vector<uint8_t> expectedOutput({ 0, 1, 0, 0, 0, 1,
1, 1, 1, 1, 1, 1 });
- ArithmeticBroadcastEndToEnd<armnn::DataType::QuantisedAsymm8, armnn::DataType::Boolean>(defaultBackends,
- LayerType::Greater,
- expectedOutput);
+ ComparisonBroadcastEndToEnd<armnn::DataType::QuantisedAsymm8>(defaultBackends,
+ ComparisonOperation::Greater,
+ expectedOutput);
}
BOOST_AUTO_TEST_CASE(RefBatchToSpaceNdEndToEndFloat32NHWCTest)
diff --git a/src/backends/reference/workloads/CMakeLists.txt b/src/backends/reference/workloads/CMakeLists.txt
index b8eb95c729..7844518620 100644
--- a/src/backends/reference/workloads/CMakeLists.txt
+++ b/src/backends/reference/workloads/CMakeLists.txt
@@ -63,6 +63,8 @@ list(APPEND armnnRefBackendWorkloads_sources
RefBatchNormalizationWorkload.hpp
RefBatchToSpaceNdWorkload.cpp
RefBatchToSpaceNdWorkload.hpp
+ RefComparisonWorkload.cpp
+ RefComparisonWorkload.hpp
RefConcatWorkload.cpp
RefConcatWorkload.hpp
RefConstantWorkload.cpp
diff --git a/src/backends/reference/workloads/ElementwiseFunction.cpp b/src/backends/reference/workloads/ElementwiseFunction.cpp
index 7a5c071f70..888037f9a6 100644
--- a/src/backends/reference/workloads/ElementwiseFunction.cpp
+++ b/src/backends/reference/workloads/ElementwiseFunction.cpp
@@ -32,6 +32,11 @@ template struct armnn::ElementwiseFunction<std::multiplies<float>>;
template struct armnn::ElementwiseFunction<std::divides<float>>;
template struct armnn::ElementwiseFunction<armnn::maximum<float>>;
template struct armnn::ElementwiseFunction<armnn::minimum<float>>;
+
+// Comparison
template struct armnn::ElementwiseFunction<std::equal_to<float>>;
template struct armnn::ElementwiseFunction<std::greater<float>>;
-
+template struct armnn::ElementwiseFunction<std::greater_equal<float>>;
+template struct armnn::ElementwiseFunction<std::less<float>>;
+template struct armnn::ElementwiseFunction<std::less_equal<float>>;
+template struct armnn::ElementwiseFunction<std::not_equal_to<float>>;
diff --git a/src/backends/reference/workloads/RefComparisonWorkload.cpp b/src/backends/reference/workloads/RefComparisonWorkload.cpp
new file mode 100644
index 0000000000..60446226be
--- /dev/null
+++ b/src/backends/reference/workloads/RefComparisonWorkload.cpp
@@ -0,0 +1,102 @@
+//
+// Copyright © 2019 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include "RefComparisonWorkload.hpp"
+
+#include "Decoders.hpp"
+#include "ElementwiseFunction.hpp"
+#include "Encoders.hpp"
+#include "RefWorkloadUtils.hpp"
+
+#include <Profiling.hpp>
+
+#include <armnn/TypesUtils.hpp>
+
+#include <functional>
+
+namespace armnn
+{
+
+RefComparisonWorkload::RefComparisonWorkload(const ComparisonQueueDescriptor& desc,
+ const WorkloadInfo& info)
+ : BaseWorkload<ComparisonQueueDescriptor>(desc, info)
+{}
+
+void RefComparisonWorkload::PostAllocationConfigure()
+{
+ const TensorInfo& inputInfo0 = GetTensorInfo(m_Data.m_Inputs[0]);
+ const TensorInfo& inputInfo1 = GetTensorInfo(m_Data.m_Inputs[1]);
+ const TensorInfo& outputInfo = GetTensorInfo(m_Data.m_Outputs[0]);
+
+ m_Input0 = MakeDecoder<InType>(inputInfo0);
+ m_Input1 = MakeDecoder<InType>(inputInfo1);
+
+ m_Output = MakeEncoder<OutType>(outputInfo);
+}
+
+void RefComparisonWorkload::Execute() const
+{
+ ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefComparisonWorkload_Execute");
+
+ const TensorInfo& inputInfo0 = GetTensorInfo(m_Data.m_Inputs[0]);
+ const TensorInfo& inputInfo1 = GetTensorInfo(m_Data.m_Inputs[1]);
+ const TensorInfo& outputInfo = GetTensorInfo(m_Data.m_Outputs[0]);
+
+ const TensorShape& inShape0 = inputInfo0.GetShape();
+ const TensorShape& inShape1 = inputInfo1.GetShape();
+ const TensorShape& outShape = outputInfo.GetShape();
+
+ m_Input0->Reset(m_Data.m_Inputs[0]->Map());
+ m_Input1->Reset(m_Data.m_Inputs[1]->Map());
+ m_Output->Reset(m_Data.m_Outputs[0]->Map());
+
+ using EqualFunction = ElementwiseFunction<std::equal_to<InType>>;
+ using GreaterFunction = ElementwiseFunction<std::greater<InType>>;
+ using GreaterOrEqualFunction = ElementwiseFunction<std::greater_equal<InType>>;
+ using LessFunction = ElementwiseFunction<std::less<InType>>;
+ using LessOrEqualFunction = ElementwiseFunction<std::less_equal<InType>>;
+ using NotEqualFunction = ElementwiseFunction<std::not_equal_to<InType>>;
+
+ switch (m_Data.m_Parameters.m_Operation)
+ {
+ case ComparisonOperation::Equal:
+ {
+ EqualFunction(inShape0, inShape1, outShape, *m_Input0, *m_Input1, *m_Output);
+ break;
+ }
+ case ComparisonOperation::Greater:
+ {
+ GreaterFunction(inShape0, inShape1, outShape, *m_Input0, *m_Input1, *m_Output);
+ break;
+ }
+ case ComparisonOperation::GreaterOrEqual:
+ {
+ GreaterOrEqualFunction(inShape0, inShape1, outShape, *m_Input0, *m_Input1, *m_Output);
+ break;
+ }
+ case ComparisonOperation::Less:
+ {
+ LessFunction(inShape0, inShape1, outShape, *m_Input0, *m_Input1, *m_Output);
+ break;
+ }
+ case ComparisonOperation::LessOrEqual:
+ {
+ LessOrEqualFunction(inShape0, inShape1, outShape, *m_Input0, *m_Input1, *m_Output);
+ break;
+ }
+ case ComparisonOperation::NotEqual:
+ {
+ NotEqualFunction(inShape0, inShape1, outShape, *m_Input0, *m_Input1, *m_Output);
+ break;
+ }
+ default:
+ {
+ throw InvalidArgumentException(std::string("Unsupported comparison operation ") +
+ GetComparisonOperationAsCString(m_Data.m_Parameters.m_Operation), CHECK_LOCATION());
+ }
+ }
+}
+
+} // namespace armnn
diff --git a/src/backends/reference/workloads/RefComparisonWorkload.hpp b/src/backends/reference/workloads/RefComparisonWorkload.hpp
new file mode 100644
index 0000000000..a19e4a0540
--- /dev/null
+++ b/src/backends/reference/workloads/RefComparisonWorkload.hpp
@@ -0,0 +1,34 @@
+//
+// Copyright © 2019 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+#include "BaseIterator.hpp"
+
+#include <backendsCommon/Workload.hpp>
+#include <backendsCommon/WorkloadData.hpp>
+
+namespace armnn
+{
+
+class RefComparisonWorkload : public BaseWorkload<ComparisonQueueDescriptor>
+{
+public:
+ using BaseWorkload<ComparisonQueueDescriptor>::m_Data;
+
+ RefComparisonWorkload(const ComparisonQueueDescriptor& descriptor, const WorkloadInfo& info);
+ void PostAllocationConfigure() override;
+ void Execute() const override;
+
+private:
+ using InType = float;
+ using OutType = bool;
+
+ std::unique_ptr<Decoder<InType>> m_Input0;
+ std::unique_ptr<Decoder<InType>> m_Input1;
+ std::unique_ptr<Encoder<OutType>> m_Output;
+};
+
+} // namespace armnn
diff --git a/src/backends/reference/workloads/RefElementwiseWorkload.cpp b/src/backends/reference/workloads/RefElementwiseWorkload.cpp
index 6431348bc2..7e02f032ef 100644
--- a/src/backends/reference/workloads/RefElementwiseWorkload.cpp
+++ b/src/backends/reference/workloads/RefElementwiseWorkload.cpp
@@ -86,11 +86,3 @@ template class armnn::RefElementwiseWorkload<armnn::maximum<float>,
template class armnn::RefElementwiseWorkload<armnn::minimum<float>,
armnn::MinimumQueueDescriptor,
armnn::StringMapping::RefMinimumWorkload_Execute>;
-
-template class armnn::RefElementwiseWorkload<std::equal_to<float>,
- armnn::EqualQueueDescriptor,
- armnn::StringMapping::RefEqualWorkload_Execute>;
-
-template class armnn::RefElementwiseWorkload<std::greater<float>,
- armnn::GreaterQueueDescriptor,
- armnn::StringMapping::RefGreaterWorkload_Execute>;
diff --git a/src/backends/reference/workloads/RefElementwiseWorkload.hpp b/src/backends/reference/workloads/RefElementwiseWorkload.hpp
index 651942e9e5..ee0d80b172 100644
--- a/src/backends/reference/workloads/RefElementwiseWorkload.hpp
+++ b/src/backends/reference/workloads/RefElementwiseWorkload.hpp
@@ -65,13 +65,4 @@ using RefMinimumWorkload =
MinimumQueueDescriptor,
StringMapping::RefMinimumWorkload_Execute>;
-using RefEqualWorkload =
- RefElementwiseWorkload<std::equal_to<float>,
- armnn::EqualQueueDescriptor,
- armnn::StringMapping::RefEqualWorkload_Execute>;
-
-using RefGreaterWorkload =
- RefElementwiseWorkload<std::greater<float>,
- armnn::GreaterQueueDescriptor,
- armnn::StringMapping::RefGreaterWorkload_Execute>;
} // armnn
diff --git a/src/backends/reference/workloads/RefWorkloads.hpp b/src/backends/reference/workloads/RefWorkloads.hpp
index 79d1935823..1f9ad4a19a 100644
--- a/src/backends/reference/workloads/RefWorkloads.hpp
+++ b/src/backends/reference/workloads/RefWorkloads.hpp
@@ -20,6 +20,7 @@
#include "RefArgMinMaxWorkload.hpp"
#include "RefBatchNormalizationWorkload.hpp"
#include "RefBatchToSpaceNdWorkload.hpp"
+#include "RefComparisonWorkload.hpp"
#include "RefConvolution2dWorkload.hpp"
#include "RefConstantWorkload.hpp"
#include "RefConcatWorkload.hpp"
diff --git a/src/backends/reference/workloads/StringMapping.hpp b/src/backends/reference/workloads/StringMapping.hpp
index 073a5a6833..1654b78088 100644
--- a/src/backends/reference/workloads/StringMapping.hpp
+++ b/src/backends/reference/workloads/StringMapping.hpp
@@ -18,9 +18,7 @@ struct StringMapping
public:
enum Id {
RefAdditionWorkload_Execute,
- RefEqualWorkload_Execute,
RefDivisionWorkload_Execute,
- RefGreaterWorkload_Execute,
RefMaximumWorkload_Execute,
RefMinimumWorkload_Execute,
RefMultiplicationWorkload_Execute,
@@ -40,8 +38,6 @@ private:
{
m_Strings[RefAdditionWorkload_Execute] = "RefAdditionWorkload_Execute";
m_Strings[RefDivisionWorkload_Execute] = "RefDivisionWorkload_Execute";
- m_Strings[RefEqualWorkload_Execute] = "RefEqualWorkload_Execute";
- m_Strings[RefGreaterWorkload_Execute] = "RefGreaterWorkload_Execute";
m_Strings[RefMaximumWorkload_Execute] = "RefMaximumWorkload_Execute";
m_Strings[RefMinimumWorkload_Execute] = "RefMinimumWorkload_Execute";
m_Strings[RefMultiplicationWorkload_Execute] = "RefMultiplicationWorkload_Execute";