From 878f02313d716e86a95e7f670e9d7b19cdba5e61 Mon Sep 17 00:00:00 2001
From: FrancisMurtagh <francis.murtagh@arm.com>
Date: Wed, 19 Dec 2018 10:56:15 +0000
Subject: IVGCVSW-2379 Add Greater Ref workload implementation

* Added the Greater operation as an element-wise workload
* Added the unit tests

Change-Id: Ie00ee30e47a5f5e17a728032eeb11a085d06c8f2
---
 src/backends/reference/RefLayerSupport.cpp         | 15 +++++++++++++++
 src/backends/reference/RefLayerSupport.hpp         |  5 +++++
 src/backends/reference/RefWorkloadFactory.cpp      |  2 +-
 src/backends/reference/test/RefLayerTests.cpp      |  8 ++++++++
 .../reference/workloads/ElementwiseFunction.cpp    |  3 ++-
 .../reference/workloads/RefElementwiseWorkload.cpp | 13 ++++++++-----
 .../reference/workloads/RefElementwiseWorkload.hpp | 21 +++++++++++++++------
 7 files changed, 54 insertions(+), 13 deletions(-)

(limited to 'src/backends/reference')

diff --git a/src/backends/reference/RefLayerSupport.cpp b/src/backends/reference/RefLayerSupport.cpp
index 2952ae1a80..a64339ec69 100644
--- a/src/backends/reference/RefLayerSupport.cpp
+++ b/src/backends/reference/RefLayerSupport.cpp
@@ -257,6 +257,21 @@ bool RefLayerSupport::IsFullyConnectedSupported(const TensorInfo& input,
                                      &TrueFunc<>);
 }
 
+bool RefLayerSupport::IsGreaterSupported(const TensorInfo& input0,
+                                         const TensorInfo& input1,
+                                         const TensorInfo& output,
+                                         Optional<std::string&> reasonIfUnsupported) const
+{
+    ignore_unused(input0);
+    ignore_unused(input1);
+    ignore_unused(output);
+    ignore_unused(reasonIfUnsupported);
+    return IsSupportedForDataTypeRef(reasonIfUnsupported,
+                                     input0.GetDataType(),
+                                     &TrueFunc<>,
+                                     &TrueFunc<>);
+}
+
 bool RefLayerSupport::IsInputSupported(const TensorInfo& input,
                                        Optional<std::string&> reasonIfUnsupported) const
 {
diff --git a/src/backends/reference/RefLayerSupport.hpp b/src/backends/reference/RefLayerSupport.hpp
index 399f7b5699..3941f4bc56 100644
--- a/src/backends/reference/RefLayerSupport.hpp
+++ b/src/backends/reference/RefLayerSupport.hpp
@@ -91,6 +91,11 @@ public:
                                    const FullyConnectedDescriptor& descriptor,
                                   Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
 
+    bool IsGreaterSupported(const TensorInfo& input0,
+                            const TensorInfo& input1,
+                            const TensorInfo& output,
+                            Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
+
     bool IsInputSupported(const TensorInfo& input,
                           Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
 
diff --git a/src/backends/reference/RefWorkloadFactory.cpp b/src/backends/reference/RefWorkloadFactory.cpp
index 8173bbb952..eb8807eef6 100644
--- a/src/backends/reference/RefWorkloadFactory.cpp
+++ b/src/backends/reference/RefWorkloadFactory.cpp
@@ -303,7 +303,7 @@ std::unique_ptr<IWorkload> RefWorkloadFactory::CreateStridedSlice(const StridedS
 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateGreater(const GreaterQueueDescriptor& descriptor,
                                                              const WorkloadInfo& info) const
 {
-    return MakeWorkload<NullWorkload, NullWorkload>(descriptor, info);
+    return MakeWorkload<RefGreaterFloat32Workload, RefGreaterUint8Workload>(descriptor, info);
 }
 
 std::unique_ptr<IWorkload> RefWorkloadFactory::CreateDebug(const DebugQueueDescriptor& descriptor,
diff --git a/src/backends/reference/test/RefLayerTests.cpp b/src/backends/reference/test/RefLayerTests.cpp
index eda58a99b1..6e7da13831 100644
--- a/src/backends/reference/test/RefLayerTests.cpp
+++ b/src/backends/reference/test/RefLayerTests.cpp
@@ -241,6 +241,14 @@ ARMNN_AUTO_TEST_CASE(EqualUint8, EqualUint8Test)
 ARMNN_AUTO_TEST_CASE(EqualBroadcast1ElementUint8, EqualBroadcast1ElementUint8Test)
 ARMNN_AUTO_TEST_CASE(EqualBroadcast1DVectorUint8, EqualBroadcast1DVectorUint8Test)
 
+// Greater
+ARMNN_AUTO_TEST_CASE(SimpleGreater, GreaterSimpleTest)
+ARMNN_AUTO_TEST_CASE(GreaterBroadcast1Element, GreaterBroadcast1ElementTest)
+ARMNN_AUTO_TEST_CASE(GreaterBroadcast1DVector, GreaterBroadcast1DVectorTest)
+ARMNN_AUTO_TEST_CASE(GreaterUint8, GreaterUint8Test)
+ARMNN_AUTO_TEST_CASE(GreaterBroadcast1ElementUint8, GreaterBroadcast1ElementUint8Test)
+ARMNN_AUTO_TEST_CASE(GreaterBroadcast1DVectorUint8, GreaterBroadcast1DVectorUint8Test)
+
 // Max
 ARMNN_AUTO_TEST_CASE(SimpleMaximum, MaximumSimpleTest)
ARMNN_AUTO_TEST_CASE(MaximumBroadcast1Element, MaximumBroadcast1ElementTest)
diff --git a/src/backends/reference/workloads/ElementwiseFunction.cpp b/src/backends/reference/workloads/ElementwiseFunction.cpp
index 18ceade113..cb8aa7089c 100644
--- a/src/backends/reference/workloads/ElementwiseFunction.cpp
+++ b/src/backends/reference/workloads/ElementwiseFunction.cpp
@@ -32,4 +32,5 @@ template struct armnn::ElementwiseFunction<std::multiplies<float>>;
 template struct armnn::ElementwiseFunction<std::divides<float>>;
 template struct armnn::ElementwiseFunction<armnn::maximum<float>>;
 template struct armnn::ElementwiseFunction<armnn::minimum<float>>;
-template struct armnn::ElementwiseFunction<std::equal_to<float>>;
\ No newline at end of file
+template struct armnn::ElementwiseFunction<std::equal_to<float>>;
+template struct armnn::ElementwiseFunction<std::greater<float>>;
diff --git a/src/backends/reference/workloads/RefElementwiseWorkload.cpp b/src/backends/reference/workloads/RefElementwiseWorkload.cpp
index d00bfd01b4..13d6e70a96 100644
--- a/src/backends/reference/workloads/RefElementwiseWorkload.cpp
+++ b/src/backends/reference/workloads/RefElementwiseWorkload.cpp
@@ -45,11 +45,11 @@ void BaseUint8ElementwiseWorkload<Functor, ParentDescriptor, DebugString>::Execu
     std::vector<float> results(outputInfo.GetNumElements());
 
     ElementwiseFunction<Functor>(inputInfo0.GetShape(),
-                        inputInfo1.GetShape(),
-                        outputInfo.GetShape(),
-                        dequant0.data(),
-                        dequant1.data(),
-                        results.data());
+                                 inputInfo1.GetShape(),
+                                 outputInfo.GetShape(),
+                                 dequant0.data(),
+                                 dequant1.data(),
+                                 results.data());
 
     Quantize(GetOutputTensorDataU8(0, data), results.data(), outputInfo);
 }
@@ -76,3 +76,6 @@ template class armnn::BaseUint8ElementwiseWorkload<armnn::MinimumQueueDescriptor
 
 template class armnn::BaseFloat32ElementwiseWorkload<armnn::EqualQueueDescriptor, std::equal_to<float>>;
 template class armnn::BaseUint8ElementwiseWorkload<armnn::EqualQueueDescriptor, std::equal_to<float>>;
+
+template class armnn::BaseFloat32ElementwiseWorkload<armnn::GreaterQueueDescriptor, std::greater<float>>;
+template class armnn::BaseUint8ElementwiseWorkload<armnn::GreaterQueueDescriptor, std::greater<float>>;
diff --git a/src/backends/reference/workloads/RefElementwiseWorkload.hpp b/src/backends/reference/workloads/RefElementwiseWorkload.hpp
index c2855b0550..1b3200f85c 100644
--- a/src/backends/reference/workloads/RefElementwiseWorkload.hpp
+++ 
b/src/backends/reference/workloads/RefElementwiseWorkload.hpp
@@ -12,8 +12,6 @@
 #include "Maximum.hpp"
 #include "Minimum.hpp"
 
-
-
 namespace armnn
 {
 
@@ -86,7 +84,6 @@ using RefAdditionUint8Workload =
                           AdditionQueueDescriptor,
                           StringMapping::RefAdditionWorkload_Execute>;
 
-
 using RefSubtractionFloat32Workload =
     RefElementwiseWorkload<std::minus<float>,
                            DataType::Float32,
@@ -132,9 +129,9 @@ using RefMaximumFloat32Workload =
 using RefMaximumUint8Workload =
     RefElementwiseWorkload<armnn::maximum<float>,
-                            DataType::QuantisedAsymm8,
-                            MaximumQueueDescriptor,
-                            StringMapping::RefMaximumWorkload_Execute>;
+                           DataType::QuantisedAsymm8,
+                           MaximumQueueDescriptor,
+                           StringMapping::RefMaximumWorkload_Execute>;
 
 using RefMinimumFloat32Workload =
     RefElementwiseWorkload<armnn::minimum<float>,
                            DataType::Float32,
@@ -159,4 +156,16 @@ using RefEqualUint8Workload =
                            DataType::QuantisedAsymm8,
                            EqualQueueDescriptor,
                            StringMapping::RefEqualWorkload_Execute>;
+
+using RefGreaterFloat32Workload =
+    RefElementwiseWorkload<std::greater<float>,
+                           DataType::Float32,
+                           GreaterQueueDescriptor,
+                           StringMapping::RefGreaterWorkload_Execute>;
+
+using RefGreaterUint8Workload =
+    RefElementwiseWorkload<std::greater<float>,
+                           DataType::QuantisedAsymm8,
+                           GreaterQueueDescriptor,
+                           StringMapping::RefGreaterWorkload_Execute>;
 } // armnn
-- 
cgit v1.2.1