about summary refs log tree commit diff
path: root/src
diff options
context:
space:
mode:
author: Francis Murtagh <francis.murtagh@arm.com> 2018-12-18 12:57:35 +0000
committer: Matteo Martincigh <matteo.martincigh@arm.com> 2018-12-18 14:33:21 +0000
commit30cdfcac03fc9f3ab424865b40c0490799e5c8fb (patch)
tree4611a128b99b60387ce84c463346a275c7266c3a /src
parentd74dc91af2d1302bf9024fcb1690d5df035f9c15 (diff)
downloadarmnn-30cdfcac03fc9f3ab424865b40c0490799e5c8fb.tar.gz
IVGCVSW-2365 Add Reference Equal Workload Implementation
* Add reference equal workload
* Add Reference Workload Unit Test

Change-Id: If2848e7dde4248566b99d91726d08143c40ff80d
Diffstat (limited to 'src')
-rw-r--r--src/backends/backendsCommon/StringMapping.hpp2
-rw-r--r--src/backends/backendsCommon/WorkloadData.cpp13
-rwxr-xr-xsrc/backends/backendsCommon/test/LayerTests.cpp172
-rw-r--r--src/backends/backendsCommon/test/LayerTests.hpp24
-rw-r--r--src/backends/cl/ClLayerSupport.cpp12
-rw-r--r--src/backends/cl/ClLayerSupport.hpp5
-rw-r--r--src/backends/neon/NeonLayerSupport.cpp12
-rw-r--r--src/backends/neon/NeonLayerSupport.hpp5
-rw-r--r--src/backends/reference/RefLayerSupport.cpp15
-rw-r--r--src/backends/reference/RefLayerSupport.hpp5
-rw-r--r--src/backends/reference/RefWorkloadFactory.cpp2
-rw-r--r--src/backends/reference/test/RefLayerTests.cpp8
-rw-r--r--src/backends/reference/workloads/ElementwiseFunction.cpp13
-rw-r--r--src/backends/reference/workloads/ElementwiseFunction.hpp10
-rw-r--r--src/backends/reference/workloads/RefElementwiseWorkload.cpp3
-rw-r--r--src/backends/reference/workloads/RefElementwiseWorkload.hpp12
16 files changed, 301 insertions, 12 deletions
diff --git a/src/backends/backendsCommon/StringMapping.hpp b/src/backends/backendsCommon/StringMapping.hpp
index aa7fb6df61..8541195356 100644
--- a/src/backends/backendsCommon/StringMapping.hpp
+++ b/src/backends/backendsCommon/StringMapping.hpp
@@ -18,6 +18,7 @@ struct StringMapping
public:
enum Id {
RefAdditionWorkload_Execute,
+ RefEqualWorkload_Execute,
RefSubtractionWorkload_Execute,
RefMaximumWorkload_Execute,
RefMultiplicationWorkload_Execute,
@@ -37,6 +38,7 @@ private:
StringMapping()
{
m_Strings[RefAdditionWorkload_Execute] = "RefAdditionWorkload_Execute";
+ m_Strings[RefEqualWorkload_Execute] = "RefEqualWorkload_Execute";
m_Strings[RefSubtractionWorkload_Execute] = "RefSubtractionWorkload_Execute";
m_Strings[RefMaximumWorkload_Execute] = "RefMaximumWorkload_Execute";
m_Strings[RefMultiplicationWorkload_Execute] = "RefMultiplicationWorkload_Execute";
diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp
index 47d5364597..67cee1cc0d 100644
--- a/src/backends/backendsCommon/WorkloadData.cpp
+++ b/src/backends/backendsCommon/WorkloadData.cpp
@@ -1013,4 +1013,17 @@ void DebugQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
ValidateSingleOutput(workloadInfo, "DebugQueueDescriptor");
}
+void EqualQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
+{
+ ValidateTwoInputs(workloadInfo, "EqualQueueDescriptor");
+ ValidateSingleOutput(workloadInfo, "EqualQueueDescriptor");
+
+ ValidateBroadcastTensorShapesMatch(workloadInfo.m_InputTensorInfos[0],
+ workloadInfo.m_InputTensorInfos[1],
+ workloadInfo.m_OutputTensorInfos[0],
+ "EqualQueueDescriptor",
+ "first input",
+ "second input");
+}
+
} //namespace armnn
diff --git a/src/backends/backendsCommon/test/LayerTests.cpp b/src/backends/backendsCommon/test/LayerTests.cpp
index b44c835cb2..4dc49f97a2 100755
--- a/src/backends/backendsCommon/test/LayerTests.cpp
+++ b/src/backends/backendsCommon/test/LayerTests.cpp
@@ -1665,6 +1665,15 @@ std::unique_ptr<armnn::IWorkload> CreateWorkload<armnn::MinimumQueueDescriptor>(
return workloadFactory.CreateMinimum(descriptor, info);
}
+template<>
+std::unique_ptr<armnn::IWorkload> CreateWorkload<armnn::EqualQueueDescriptor>(
+ const armnn::IWorkloadFactory& workloadFactory,
+ const armnn::WorkloadInfo& info,
+ const armnn::EqualQueueDescriptor& descriptor)
+{
+ return workloadFactory.CreateEqual(descriptor, info);
+}
+
namespace {
template <typename Descriptor, typename dataType>
LayerTestResult<dataType, 4> ElementwiseTestHelper
@@ -1724,6 +1733,169 @@ namespace {
}
}
+LayerTestResult<float, 4> EqualSimpleTest(armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
+{
+ const unsigned int width = 2;
+ const unsigned int height = 2;
+ const unsigned int channelCount = 2;
+ const unsigned int batchSize = 2;
+
+ unsigned int shape[] = { batchSize, channelCount, height, width };
+
+ std::vector<float> input0({ 1, 1, 1, 1, 5, 5, 5, 5,
+ 3, 3, 3, 3, 4, 4, 4, 4 });
+
+ std::vector<float> input1({ 1, 1, 1, 1, 3, 3, 3, 3,
+ 5, 5, 5, 5, 4, 4, 4, 4 });
+
+ std::vector<float> output({ 1, 1, 1, 1, 0, 0, 0, 0,
+ 0, 0, 0, 0, 1, 1, 1, 1 });
+
+ return ElementwiseTestHelper<armnn::EqualQueueDescriptor, float>
+ (workloadFactory,
+ memoryManager,
+ shape,
+ input0,
+ shape,
+ input1,
+ shape,
+ output);
+}
+
+LayerTestResult<float, 4> EqualBroadcast1ElementTest(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
+{
+ unsigned int shape0[] = { 1, 2, 2, 2 };
+ std::vector<float> input0({ 1, 2, 3, 4, 5, 6, 7, 8});
+
+ unsigned int shape1[] = { 1, 1, 1, 1 };
+ std::vector<float> input1({ 1 });
+
+ std::vector<float> output({ 1, 0, 0, 0, 0, 0, 0, 0});
+
+ return ElementwiseTestHelper<armnn::EqualQueueDescriptor, float>
+ (workloadFactory,
+ memoryManager,
+ shape0,
+ input0,
+ shape1,
+ input1,
+ shape0,
+ output);
+}
+
+LayerTestResult<float, 4> EqualBroadcast1DVectorTest(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
+{
+ const unsigned int shape0[] = { 1, 2, 2, 3 };
+ const unsigned int shape1[] = { 1, 1, 1, 3 };
+
+ std::vector<float> input0({ 1, 2, 3, 4, 5, 6,
+ 7, 8, 9, 10, 11, 12 });
+
+ std::vector<float> input1({ 1, 2, 3});
+
+ std::vector<float> output({ 1, 1, 1, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0 });
+
+ return ElementwiseTestHelper<armnn::EqualQueueDescriptor, float>
+ (workloadFactory,
+ memoryManager,
+ shape0,
+ input0,
+ shape1,
+ input1,
+ shape0,
+ output);
+}
+
+LayerTestResult<uint8_t, 4> EqualUint8Test(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
+{
+ unsigned int shape[] = { 2, 2, 2, 2 };
+
+ // See dequantized values to the right.
+ std::vector<uint8_t> input0({ 1, 1, 1, 1, 6, 6, 6, 6,
+ 3, 3, 3, 3, 5, 5, 5, 5 });
+
+ std::vector<uint8_t> input1({ 2, 2, 2, 2, 6, 6, 6, 6,
+ 3, 3, 3, 3, 5, 5, 5, 5 });
+
+ std::vector<uint8_t> output({ 0, 0, 0, 0, 1, 1, 1, 1,
+ 1, 1, 1, 1, 0, 0, 0, 0 });
+
+ return ElementwiseTestHelper<armnn::EqualQueueDescriptor, uint8_t >
+ (workloadFactory,
+ memoryManager,
+ shape,
+ input0,
+ shape,
+ input1,
+ shape,
+ output,
+ 1.0f,
+ 0);
+}
+
+LayerTestResult<uint8_t, 4> EqualBroadcast1ElementUint8Test(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
+{
+ const unsigned int shape0[] = { 1, 2, 2, 3 };
+ const unsigned int shape1[] = { 1, 1, 1, 1 };
+
+ std::vector<uint8_t> input0({ 1, 2, 3, 4, 5, 6,
+ 7, 8, 9, 10, 11, 12 });
+
+ std::vector<uint8_t> input1({ 1 });
+
+ std::vector<uint8_t> output({ 1, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0 });
+
+ return ElementwiseTestHelper<armnn::EqualQueueDescriptor, uint8_t >
+ (workloadFactory,
+ memoryManager,
+ shape0,
+ input0,
+ shape1,
+ input1,
+ shape0,
+ output,
+ 1.0f,
+ 0);
+}
+
+LayerTestResult<uint8_t, 4> EqualBroadcast1DVectorUint8Test(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
+{
+ const unsigned int shape0[] = { 1, 2, 2, 3 };
+ const unsigned int shape1[] = { 1, 1, 1, 3 };
+
+ std::vector<uint8_t> input0({ 1, 2, 3, 4, 5, 6,
+ 7, 8, 9, 10, 11, 12 });
+
+ std::vector<uint8_t> input1({ 1, 1, 3});
+
+ std::vector<uint8_t> output({ 1, 0, 1, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0 });
+
+ return ElementwiseTestHelper<armnn::EqualQueueDescriptor, uint8_t>
+ (workloadFactory,
+ memoryManager,
+ shape0,
+ input0,
+ shape1,
+ input1,
+ shape0,
+ output,
+ 1.0f,
+ 0);
+}
LayerTestResult<float, 4> MaximumSimpleTest(armnn::IWorkloadFactory& workloadFactory,
const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
diff --git a/src/backends/backendsCommon/test/LayerTests.hpp b/src/backends/backendsCommon/test/LayerTests.hpp
index 1f38675b37..029418e850 100644
--- a/src/backends/backendsCommon/test/LayerTests.hpp
+++ b/src/backends/backendsCommon/test/LayerTests.hpp
@@ -867,6 +867,30 @@ LayerTestResult<uint8_t, 3> Concatenation3dDim2DiffInputDimsUint8Test(
const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
bool useSubtensor);
+LayerTestResult<float, 4> EqualSimpleTest(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager);
+
+LayerTestResult<float, 4> EqualBroadcast1ElementTest(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager);
+
+LayerTestResult<float, 4> EqualBroadcast1DVectorTest(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager);
+
+LayerTestResult<uint8_t, 4> EqualUint8Test(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager);
+
+LayerTestResult<uint8_t, 4> EqualBroadcast1ElementUint8Test(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager);
+
+LayerTestResult<uint8_t, 4> EqualBroadcast1DVectorUint8Test(
+ armnn::IWorkloadFactory& workloadFactory,
+ const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager);
+
LayerTestResult<float, 2> FullyConnectedLargeTest(
armnn::IWorkloadFactory& workloadFactory,
const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
diff --git a/src/backends/cl/ClLayerSupport.cpp b/src/backends/cl/ClLayerSupport.cpp
index eebab352e1..82b46476ea 100644
--- a/src/backends/cl/ClLayerSupport.cpp
+++ b/src/backends/cl/ClLayerSupport.cpp
@@ -252,6 +252,18 @@ bool ClLayerSupport::IsDivisionSupported(const TensorInfo& input0,
output);
}
+bool ClLayerSupport::IsEqualSupported(const TensorInfo& input0,
+ const TensorInfo& input1,
+ const TensorInfo& output,
+ Optional<std::string&> reasonIfUnsupported) const
+{
+ ignore_unused(input0);
+ ignore_unused(input1);
+ ignore_unused(output);
+ ignore_unused(reasonIfUnsupported);
+ return false;
+}
+
bool ClLayerSupport::IsFakeQuantizationSupported(const TensorInfo& input,
const FakeQuantizationDescriptor& descriptor,
Optional<std::string&> reasonIfUnsupported) const
diff --git a/src/backends/cl/ClLayerSupport.hpp b/src/backends/cl/ClLayerSupport.hpp
index 470fa2acb4..82efd0016a 100644
--- a/src/backends/cl/ClLayerSupport.hpp
+++ b/src/backends/cl/ClLayerSupport.hpp
@@ -66,6 +66,11 @@ public:
const TensorInfo& output,
Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
+ bool IsEqualSupported(const TensorInfo& input0,
+ const TensorInfo& input1,
+ const TensorInfo& output,
+ Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
+
bool IsFakeQuantizationSupported(const TensorInfo& input,
const FakeQuantizationDescriptor& descriptor,
Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
diff --git a/src/backends/neon/NeonLayerSupport.cpp b/src/backends/neon/NeonLayerSupport.cpp
index 36c9f8bc08..0033b86917 100644
--- a/src/backends/neon/NeonLayerSupport.cpp
+++ b/src/backends/neon/NeonLayerSupport.cpp
@@ -210,6 +210,18 @@ bool NeonLayerSupport::IsDivisionSupported(const TensorInfo& input0,
return false;
}
+bool NeonLayerSupport::IsEqualSupported(const TensorInfo& input0,
+ const TensorInfo& input1,
+ const TensorInfo& output,
+ Optional<std::string&> reasonIfUnsupported) const
+{
+ ignore_unused(input0);
+ ignore_unused(input1);
+ ignore_unused(output);
+ ignore_unused(reasonIfUnsupported);
+ return false;
+}
+
bool NeonLayerSupport::IsFakeQuantizationSupported(const TensorInfo& input,
const FakeQuantizationDescriptor& descriptor,
Optional<std::string&> reasonIfUnsupported) const
diff --git a/src/backends/neon/NeonLayerSupport.hpp b/src/backends/neon/NeonLayerSupport.hpp
index e5cd3cc062..5724ed85df 100644
--- a/src/backends/neon/NeonLayerSupport.hpp
+++ b/src/backends/neon/NeonLayerSupport.hpp
@@ -61,6 +61,11 @@ public:
const TensorInfo& output,
Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
+ bool IsEqualSupported(const TensorInfo& input0,
+ const TensorInfo& input1,
+ const TensorInfo& output,
+ Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
+
bool IsFakeQuantizationSupported(const TensorInfo& input,
const FakeQuantizationDescriptor& descriptor,
Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
diff --git a/src/backends/reference/RefLayerSupport.cpp b/src/backends/reference/RefLayerSupport.cpp
index 2c8f9cb6e1..2952ae1a80 100644
--- a/src/backends/reference/RefLayerSupport.cpp
+++ b/src/backends/reference/RefLayerSupport.cpp
@@ -203,6 +203,21 @@ bool RefLayerSupport::IsDivisionSupported(const TensorInfo& input0,
&TrueFunc<>);
}
+bool RefLayerSupport::IsEqualSupported(const TensorInfo& input0,
+ const TensorInfo& input1,
+ const TensorInfo& output,
+ Optional<std::string&> reasonIfUnsupported) const
+{
+ ignore_unused(input0);
+ ignore_unused(input1);
+ ignore_unused(output);
+ ignore_unused(reasonIfUnsupported);
+ return IsSupportedForDataTypeRef(reasonIfUnsupported,
+ input0.GetDataType(),
+ &TrueFunc<>,
+ &TrueFunc<>);
+}
+
bool RefLayerSupport::IsFakeQuantizationSupported(const TensorInfo& input,
const FakeQuantizationDescriptor& descriptor,
Optional<std::string&> reasonIfUnsupported) const
diff --git a/src/backends/reference/RefLayerSupport.hpp b/src/backends/reference/RefLayerSupport.hpp
index 9dc64cb37c..399f7b5699 100644
--- a/src/backends/reference/RefLayerSupport.hpp
+++ b/src/backends/reference/RefLayerSupport.hpp
@@ -71,6 +71,11 @@ public:
const TensorInfo& output,
Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
+ bool IsEqualSupported(const TensorInfo& input0,
+ const TensorInfo& input1,
+ const TensorInfo& output,
+ Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
+
bool IsFakeQuantizationSupported(const TensorInfo& input,
const FakeQuantizationDescriptor& descriptor,
Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
diff --git a/src/backends/reference/RefWorkloadFactory.cpp b/src/backends/reference/RefWorkloadFactory.cpp
index 110a947491..8173bbb952 100644
--- a/src/backends/reference/RefWorkloadFactory.cpp
+++ b/src/backends/reference/RefWorkloadFactory.cpp
@@ -285,7 +285,7 @@ std::unique_ptr<IWorkload> RefWorkloadFactory::CreatePad(const PadQueueDescripto
std::unique_ptr<IWorkload> RefWorkloadFactory::CreateEqual(const EqualQueueDescriptor& descriptor,
const WorkloadInfo& info) const
{
- return MakeWorkload<NullWorkload, NullWorkload>(descriptor, info);
+ return MakeWorkload<RefEqualFloat32Workload, RefEqualUint8Workload>(descriptor, info);
}
std::unique_ptr<IWorkload> RefWorkloadFactory::CreateBatchToSpaceNd(const BatchToSpaceNdQueueDescriptor& descriptor,
diff --git a/src/backends/reference/test/RefLayerTests.cpp b/src/backends/reference/test/RefLayerTests.cpp
index d3c2231a23..eda58a99b1 100644
--- a/src/backends/reference/test/RefLayerTests.cpp
+++ b/src/backends/reference/test/RefLayerTests.cpp
@@ -233,6 +233,14 @@ ARMNN_AUTO_TEST_CASE(DivisionUint8, DivisionUint8Test)
ARMNN_AUTO_TEST_CASE(DivisionUint8Broadcast1Element, DivisionBroadcast1ElementUint8Test)
ARMNN_AUTO_TEST_CASE(DivisionUint8Broadcast1DVector, DivisionBroadcast1DVectorUint8Test)
+// Equal
+ARMNN_AUTO_TEST_CASE(SimpleEqual, EqualSimpleTest)
+ARMNN_AUTO_TEST_CASE(EqualBroadcast1Element, EqualBroadcast1ElementTest)
+ARMNN_AUTO_TEST_CASE(EqualBroadcast1DVector, EqualBroadcast1DVectorTest)
+ARMNN_AUTO_TEST_CASE(EqualUint8, EqualUint8Test)
+ARMNN_AUTO_TEST_CASE(EqualBroadcast1ElementUint8, EqualBroadcast1ElementUint8Test)
+ARMNN_AUTO_TEST_CASE(EqualBroadcast1DVectorUint8, EqualBroadcast1DVectorUint8Test)
+
// Max
ARMNN_AUTO_TEST_CASE(SimpleMaximum, MaximumSimpleTest)
ARMNN_AUTO_TEST_CASE(MaximumBroadcast1Element, MaximumBroadcast1ElementTest)
diff --git a/src/backends/reference/workloads/ElementwiseFunction.cpp b/src/backends/reference/workloads/ElementwiseFunction.cpp
index 88d51908fe..18ceade113 100644
--- a/src/backends/reference/workloads/ElementwiseFunction.cpp
+++ b/src/backends/reference/workloads/ElementwiseFunction.cpp
@@ -15,11 +15,11 @@ namespace armnn
template <typename Functor>
ElementwiseFunction<Functor>::ElementwiseFunction(const TensorShape& inShape0,
- const TensorShape& inShape1,
- const TensorShape& outShape,
- const float* inData0,
- const float* inData1,
- float* outData)
+ const TensorShape& inShape1,
+ const TensorShape& outShape,
+ const float* inData0,
+ const float* inData1,
+ float* outData)
{
BroadcastLoop(inShape0, inShape1, outShape).Unroll(Functor(), 0, inData0, inData1, outData);
}
@@ -31,4 +31,5 @@ template struct armnn::ElementwiseFunction<std::minus<float>>;
template struct armnn::ElementwiseFunction<std::multiplies<float>>;
template struct armnn::ElementwiseFunction<std::divides<float>>;
template struct armnn::ElementwiseFunction<armnn::maximum<float>>;
-template struct armnn::ElementwiseFunction<armnn::minimum<float>>; \ No newline at end of file
+template struct armnn::ElementwiseFunction<armnn::minimum<float>>;
+template struct armnn::ElementwiseFunction<std::equal_to<float>>; \ No newline at end of file
diff --git a/src/backends/reference/workloads/ElementwiseFunction.hpp b/src/backends/reference/workloads/ElementwiseFunction.hpp
index 5011616c0c..0ac136466c 100644
--- a/src/backends/reference/workloads/ElementwiseFunction.hpp
+++ b/src/backends/reference/workloads/ElementwiseFunction.hpp
@@ -14,11 +14,11 @@ template <typename Functor>
struct ElementwiseFunction
{
ElementwiseFunction(const TensorShape& inShape0,
- const TensorShape& inShape1,
- const TensorShape& outShape,
- const float* inData0,
- const float* inData1,
- float* outData);
+ const TensorShape& inShape1,
+ const TensorShape& outShape,
+ const float* inData0,
+ const float* inData1,
+ float* outData);
};
} //namespace armnn
diff --git a/src/backends/reference/workloads/RefElementwiseWorkload.cpp b/src/backends/reference/workloads/RefElementwiseWorkload.cpp
index a18c7c569e..d00bfd01b4 100644
--- a/src/backends/reference/workloads/RefElementwiseWorkload.cpp
+++ b/src/backends/reference/workloads/RefElementwiseWorkload.cpp
@@ -73,3 +73,6 @@ template class armnn::BaseUint8ElementwiseWorkload<armnn::MaximumQueueDescriptor
template class armnn::BaseFloat32ElementwiseWorkload<armnn::MinimumQueueDescriptor, armnn::minimum<float>>;
template class armnn::BaseUint8ElementwiseWorkload<armnn::MinimumQueueDescriptor, armnn::minimum<float>>;
+
+template class armnn::BaseFloat32ElementwiseWorkload<armnn::EqualQueueDescriptor, std::equal_to<float>>;
+template class armnn::BaseUint8ElementwiseWorkload<armnn::EqualQueueDescriptor, std::equal_to<float>>;
diff --git a/src/backends/reference/workloads/RefElementwiseWorkload.hpp b/src/backends/reference/workloads/RefElementwiseWorkload.hpp
index b5205938b2..c2855b0550 100644
--- a/src/backends/reference/workloads/RefElementwiseWorkload.hpp
+++ b/src/backends/reference/workloads/RefElementwiseWorkload.hpp
@@ -147,4 +147,16 @@ using RefMinimumUint8Workload =
DataType::QuantisedAsymm8,
MinimumQueueDescriptor,
StringMapping::RefMinimumWorkload_Execute>;
+
+using RefEqualFloat32Workload =
+ RefElementwiseWorkload<std::equal_to<float>,
+ DataType::Float32,
+ EqualQueueDescriptor,
+ StringMapping::RefEqualWorkload_Execute>;
+
+using RefEqualUint8Workload =
+ RefElementwiseWorkload<std::equal_to<float>,
+ DataType::QuantisedAsymm8,
+ EqualQueueDescriptor,
+ StringMapping::RefEqualWorkload_Execute>;
} // armnn