aboutsummaryrefslogtreecommitdiff
path: root/src/backends/reference
diff options
context:
space:
mode:
authorFrancisMurtagh <francis.murtagh@arm.com>2018-12-18 12:57:35 +0000
committerMatteo Martincigh <matteo.martincigh@arm.com>2018-12-18 14:33:21 +0000
commit30cdfcac03fc9f3ab424865b40c0490799e5c8fb (patch)
tree4611a128b99b60387ce84c463346a275c7266c3a /src/backends/reference
parentd74dc91af2d1302bf9024fcb1690d5df035f9c15 (diff)
downloadarmnn-30cdfcac03fc9f3ab424865b40c0490799e5c8fb.tar.gz
IVGCVSW-2365 Add Reference Equal Workload Implementation
* Add reference equal workload * Add Reference Workload Unit Test Change-Id: If2848e7dde4248566b99d91726d08143c40ff80d
Diffstat (limited to 'src/backends/reference')
-rw-r--r--src/backends/reference/RefLayerSupport.cpp15
-rw-r--r--src/backends/reference/RefLayerSupport.hpp5
-rw-r--r--src/backends/reference/RefWorkloadFactory.cpp2
-rw-r--r--src/backends/reference/test/RefLayerTests.cpp8
-rw-r--r--src/backends/reference/workloads/ElementwiseFunction.cpp13
-rw-r--r--src/backends/reference/workloads/ElementwiseFunction.hpp10
-rw-r--r--src/backends/reference/workloads/RefElementwiseWorkload.cpp3
-rw-r--r--src/backends/reference/workloads/RefElementwiseWorkload.hpp12
8 files changed, 56 insertions, 12 deletions
diff --git a/src/backends/reference/RefLayerSupport.cpp b/src/backends/reference/RefLayerSupport.cpp
index 2c8f9cb6e1..2952ae1a80 100644
--- a/src/backends/reference/RefLayerSupport.cpp
+++ b/src/backends/reference/RefLayerSupport.cpp
@@ -203,6 +203,21 @@ bool RefLayerSupport::IsDivisionSupported(const TensorInfo& input0,
&TrueFunc<>);
}
+bool RefLayerSupport::IsEqualSupported(const TensorInfo& input0,
+ const TensorInfo& input1,
+ const TensorInfo& output,
+ Optional<std::string&> reasonIfUnsupported) const
+{
+ ignore_unused(input0);
+ ignore_unused(input1);
+ ignore_unused(output);
+ ignore_unused(reasonIfUnsupported);
+ return IsSupportedForDataTypeRef(reasonIfUnsupported,
+ input0.GetDataType(),
+ &TrueFunc<>,
+ &TrueFunc<>);
+}
+
bool RefLayerSupport::IsFakeQuantizationSupported(const TensorInfo& input,
const FakeQuantizationDescriptor& descriptor,
Optional<std::string&> reasonIfUnsupported) const
diff --git a/src/backends/reference/RefLayerSupport.hpp b/src/backends/reference/RefLayerSupport.hpp
index 9dc64cb37c..399f7b5699 100644
--- a/src/backends/reference/RefLayerSupport.hpp
+++ b/src/backends/reference/RefLayerSupport.hpp
@@ -71,6 +71,11 @@ public:
const TensorInfo& output,
Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
+ bool IsEqualSupported(const TensorInfo& input0,
+ const TensorInfo& input1,
+ const TensorInfo& output,
+ Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
+
bool IsFakeQuantizationSupported(const TensorInfo& input,
const FakeQuantizationDescriptor& descriptor,
Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
diff --git a/src/backends/reference/RefWorkloadFactory.cpp b/src/backends/reference/RefWorkloadFactory.cpp
index 110a947491..8173bbb952 100644
--- a/src/backends/reference/RefWorkloadFactory.cpp
+++ b/src/backends/reference/RefWorkloadFactory.cpp
@@ -285,7 +285,7 @@ std::unique_ptr<IWorkload> RefWorkloadFactory::CreatePad(const PadQueueDescripto
std::unique_ptr<IWorkload> RefWorkloadFactory::CreateEqual(const EqualQueueDescriptor& descriptor,
const WorkloadInfo& info) const
{
- return MakeWorkload<NullWorkload, NullWorkload>(descriptor, info);
+ return MakeWorkload<RefEqualFloat32Workload, RefEqualUint8Workload>(descriptor, info);
}
std::unique_ptr<IWorkload> RefWorkloadFactory::CreateBatchToSpaceNd(const BatchToSpaceNdQueueDescriptor& descriptor,
diff --git a/src/backends/reference/test/RefLayerTests.cpp b/src/backends/reference/test/RefLayerTests.cpp
index d3c2231a23..eda58a99b1 100644
--- a/src/backends/reference/test/RefLayerTests.cpp
+++ b/src/backends/reference/test/RefLayerTests.cpp
@@ -233,6 +233,14 @@ ARMNN_AUTO_TEST_CASE(DivisionUint8, DivisionUint8Test)
ARMNN_AUTO_TEST_CASE(DivisionUint8Broadcast1Element, DivisionBroadcast1ElementUint8Test)
ARMNN_AUTO_TEST_CASE(DivisionUint8Broadcast1DVector, DivisionBroadcast1DVectorUint8Test)
+// Equal
+ARMNN_AUTO_TEST_CASE(SimpleEqual, EqualSimpleTest)
+ARMNN_AUTO_TEST_CASE(EqualBroadcast1Element, EqualBroadcast1ElementTest)
+ARMNN_AUTO_TEST_CASE(EqualBroadcast1DVector, EqualBroadcast1DVectorTest)
+ARMNN_AUTO_TEST_CASE(EqualUint8, EqualUint8Test)
+ARMNN_AUTO_TEST_CASE(EqualBroadcast1ElementUint8, EqualBroadcast1ElementUint8Test)
+ARMNN_AUTO_TEST_CASE(EqualBroadcast1DVectorUint8, EqualBroadcast1DVectorUint8Test)
+
// Max
ARMNN_AUTO_TEST_CASE(SimpleMaximum, MaximumSimpleTest)
ARMNN_AUTO_TEST_CASE(MaximumBroadcast1Element, MaximumBroadcast1ElementTest)
diff --git a/src/backends/reference/workloads/ElementwiseFunction.cpp b/src/backends/reference/workloads/ElementwiseFunction.cpp
index 88d51908fe..18ceade113 100644
--- a/src/backends/reference/workloads/ElementwiseFunction.cpp
+++ b/src/backends/reference/workloads/ElementwiseFunction.cpp
@@ -15,11 +15,11 @@ namespace armnn
template <typename Functor>
ElementwiseFunction<Functor>::ElementwiseFunction(const TensorShape& inShape0,
- const TensorShape& inShape1,
- const TensorShape& outShape,
- const float* inData0,
- const float* inData1,
- float* outData)
+ const TensorShape& inShape1,
+ const TensorShape& outShape,
+ const float* inData0,
+ const float* inData1,
+ float* outData)
{
BroadcastLoop(inShape0, inShape1, outShape).Unroll(Functor(), 0, inData0, inData1, outData);
}
@@ -31,4 +31,5 @@ template struct armnn::ElementwiseFunction<std::minus<float>>;
template struct armnn::ElementwiseFunction<std::multiplies<float>>;
template struct armnn::ElementwiseFunction<std::divides<float>>;
template struct armnn::ElementwiseFunction<armnn::maximum<float>>;
-template struct armnn::ElementwiseFunction<armnn::minimum<float>>; \ No newline at end of file
+template struct armnn::ElementwiseFunction<armnn::minimum<float>>;
+template struct armnn::ElementwiseFunction<std::equal_to<float>>; \ No newline at end of file
diff --git a/src/backends/reference/workloads/ElementwiseFunction.hpp b/src/backends/reference/workloads/ElementwiseFunction.hpp
index 5011616c0c..0ac136466c 100644
--- a/src/backends/reference/workloads/ElementwiseFunction.hpp
+++ b/src/backends/reference/workloads/ElementwiseFunction.hpp
@@ -14,11 +14,11 @@ template <typename Functor>
struct ElementwiseFunction
{
ElementwiseFunction(const TensorShape& inShape0,
- const TensorShape& inShape1,
- const TensorShape& outShape,
- const float* inData0,
- const float* inData1,
- float* outData);
+ const TensorShape& inShape1,
+ const TensorShape& outShape,
+ const float* inData0,
+ const float* inData1,
+ float* outData);
};
} //namespace armnn
diff --git a/src/backends/reference/workloads/RefElementwiseWorkload.cpp b/src/backends/reference/workloads/RefElementwiseWorkload.cpp
index a18c7c569e..d00bfd01b4 100644
--- a/src/backends/reference/workloads/RefElementwiseWorkload.cpp
+++ b/src/backends/reference/workloads/RefElementwiseWorkload.cpp
@@ -73,3 +73,6 @@ template class armnn::BaseUint8ElementwiseWorkload<armnn::MaximumQueueDescriptor
template class armnn::BaseFloat32ElementwiseWorkload<armnn::MinimumQueueDescriptor, armnn::minimum<float>>;
template class armnn::BaseUint8ElementwiseWorkload<armnn::MinimumQueueDescriptor, armnn::minimum<float>>;
+
+template class armnn::BaseFloat32ElementwiseWorkload<armnn::EqualQueueDescriptor, std::equal_to<float>>;
+template class armnn::BaseUint8ElementwiseWorkload<armnn::EqualQueueDescriptor, std::equal_to<float>>;
diff --git a/src/backends/reference/workloads/RefElementwiseWorkload.hpp b/src/backends/reference/workloads/RefElementwiseWorkload.hpp
index b5205938b2..c2855b0550 100644
--- a/src/backends/reference/workloads/RefElementwiseWorkload.hpp
+++ b/src/backends/reference/workloads/RefElementwiseWorkload.hpp
@@ -147,4 +147,16 @@ using RefMinimumUint8Workload =
DataType::QuantisedAsymm8,
MinimumQueueDescriptor,
StringMapping::RefMinimumWorkload_Execute>;
+
+using RefEqualFloat32Workload =
+ RefElementwiseWorkload<std::equal_to<float>,
+ DataType::Float32,
+ EqualQueueDescriptor,
+ StringMapping::RefEqualWorkload_Execute>;
+
+using RefEqualUint8Workload =
+ RefElementwiseWorkload<std::equal_to<float>,
+ DataType::QuantisedAsymm8,
+ EqualQueueDescriptor,
+ StringMapping::RefEqualWorkload_Execute>;
} // armnn