aboutsummaryrefslogtreecommitdiff
path: root/src/backends/reference
diff options
context:
space:
mode:
authorFrancisMurtagh <francis.murtagh@arm.com>2018-12-19 10:56:15 +0000
committerFrancisMurtagh <francis.murtagh@arm.com>2018-12-19 10:56:15 +0000
commit878f02313d716e86a95e7f670e9d7b19cdba5e61 (patch)
tree6c80992ad989aca287bec06d713ceb93f1eaf656 /src/backends/reference
parent30cdfcac03fc9f3ab424865b40c0490799e5c8fb (diff)
downloadarmnn-878f02313d716e86a95e7f670e9d7b19cdba5e61.tar.gz
IVGCVSW-2379 Add Greater Ref workload implementation
* Added the Greater operation as an element-wise workload * Added the unit tests Change-Id: Ie00ee30e47a5f5e17a728032eeb11a085d06c8f2
Diffstat (limited to 'src/backends/reference')
-rw-r--r--src/backends/reference/RefLayerSupport.cpp15
-rw-r--r--src/backends/reference/RefLayerSupport.hpp5
-rw-r--r--src/backends/reference/RefWorkloadFactory.cpp2
-rw-r--r--src/backends/reference/test/RefLayerTests.cpp8
-rw-r--r--src/backends/reference/workloads/ElementwiseFunction.cpp3
-rw-r--r--src/backends/reference/workloads/RefElementwiseWorkload.cpp13
-rw-r--r--src/backends/reference/workloads/RefElementwiseWorkload.hpp21
7 files changed, 54 insertions, 13 deletions
diff --git a/src/backends/reference/RefLayerSupport.cpp b/src/backends/reference/RefLayerSupport.cpp
index 2952ae1a80..a64339ec69 100644
--- a/src/backends/reference/RefLayerSupport.cpp
+++ b/src/backends/reference/RefLayerSupport.cpp
@@ -257,6 +257,21 @@ bool RefLayerSupport::IsFullyConnectedSupported(const TensorInfo& input,
&TrueFunc<>);
}
+bool RefLayerSupport::IsGreaterSupported(const TensorInfo& input0,
+ const TensorInfo& input1,
+ const TensorInfo& output,
+ Optional<std::string&> reasonIfUnsupported) const
+{
+ ignore_unused(input0);
+ ignore_unused(input1);
+ ignore_unused(output);
+ ignore_unused(reasonIfUnsupported);
+ return IsSupportedForDataTypeRef(reasonIfUnsupported,
+ input0.GetDataType(),
+ &TrueFunc<>,
+ &TrueFunc<>);
+}
+
bool RefLayerSupport::IsInputSupported(const TensorInfo& input,
Optional<std::string&> reasonIfUnsupported) const
{
diff --git a/src/backends/reference/RefLayerSupport.hpp b/src/backends/reference/RefLayerSupport.hpp
index 399f7b5699..3941f4bc56 100644
--- a/src/backends/reference/RefLayerSupport.hpp
+++ b/src/backends/reference/RefLayerSupport.hpp
@@ -91,6 +91,11 @@ public:
const FullyConnectedDescriptor& descriptor,
Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
+ bool IsGreaterSupported(const TensorInfo& input0,
+ const TensorInfo& input1,
+ const TensorInfo& ouput,
+ Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
+
bool IsInputSupported(const TensorInfo& input,
Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
diff --git a/src/backends/reference/RefWorkloadFactory.cpp b/src/backends/reference/RefWorkloadFactory.cpp
index 8173bbb952..eb8807eef6 100644
--- a/src/backends/reference/RefWorkloadFactory.cpp
+++ b/src/backends/reference/RefWorkloadFactory.cpp
@@ -303,7 +303,7 @@ std::unique_ptr<IWorkload> RefWorkloadFactory::CreateStridedSlice(const StridedS
std::unique_ptr<IWorkload> RefWorkloadFactory::CreateGreater(const GreaterQueueDescriptor& descriptor,
const WorkloadInfo& info) const
{
- return MakeWorkload<NullWorkload, NullWorkload>(descriptor, info);
+ return MakeWorkload<RefGreaterFloat32Workload, RefGreaterUint8Workload>(descriptor, info);
}
std::unique_ptr<IWorkload> RefWorkloadFactory::CreateDebug(const DebugQueueDescriptor& descriptor,
diff --git a/src/backends/reference/test/RefLayerTests.cpp b/src/backends/reference/test/RefLayerTests.cpp
index eda58a99b1..6e7da13831 100644
--- a/src/backends/reference/test/RefLayerTests.cpp
+++ b/src/backends/reference/test/RefLayerTests.cpp
@@ -241,6 +241,14 @@ ARMNN_AUTO_TEST_CASE(EqualUint8, EqualUint8Test)
ARMNN_AUTO_TEST_CASE(EqualBroadcast1ElementUint8, EqualBroadcast1ElementUint8Test)
ARMNN_AUTO_TEST_CASE(EqualBroadcast1DVectorUint8, EqualBroadcast1DVectorUint8Test)
+// Greater
+ARMNN_AUTO_TEST_CASE(SimpleGreater, GreaterSimpleTest)
+ARMNN_AUTO_TEST_CASE(GreaterBroadcast1Element, GreaterBroadcast1ElementTest)
+ARMNN_AUTO_TEST_CASE(GreaterBroadcast1DVector, GreaterBroadcast1DVectorTest)
+ARMNN_AUTO_TEST_CASE(GreaterUint8, GreaterUint8Test)
+ARMNN_AUTO_TEST_CASE(GreaterBroadcast1ElementUint8, GreaterBroadcast1ElementUint8Test)
+ARMNN_AUTO_TEST_CASE(GreaterBroadcast1DVectorUint8, GreaterBroadcast1DVectorUint8Test)
+
// Max
ARMNN_AUTO_TEST_CASE(SimpleMaximum, MaximumSimpleTest)
ARMNN_AUTO_TEST_CASE(MaximumBroadcast1Element, MaximumBroadcast1ElementTest)
diff --git a/src/backends/reference/workloads/ElementwiseFunction.cpp b/src/backends/reference/workloads/ElementwiseFunction.cpp
index 18ceade113..cb8aa7089c 100644
--- a/src/backends/reference/workloads/ElementwiseFunction.cpp
+++ b/src/backends/reference/workloads/ElementwiseFunction.cpp
@@ -32,4 +32,5 @@ template struct armnn::ElementwiseFunction<std::multiplies<float>>;
template struct armnn::ElementwiseFunction<std::divides<float>>;
template struct armnn::ElementwiseFunction<armnn::maximum<float>>;
template struct armnn::ElementwiseFunction<armnn::minimum<float>>;
-template struct armnn::ElementwiseFunction<std::equal_to<float>>; \ No newline at end of file
+template struct armnn::ElementwiseFunction<std::equal_to<float>>;
+template struct armnn::ElementwiseFunction<std::greater<float>>;
diff --git a/src/backends/reference/workloads/RefElementwiseWorkload.cpp b/src/backends/reference/workloads/RefElementwiseWorkload.cpp
index d00bfd01b4..13d6e70a96 100644
--- a/src/backends/reference/workloads/RefElementwiseWorkload.cpp
+++ b/src/backends/reference/workloads/RefElementwiseWorkload.cpp
@@ -45,11 +45,11 @@ void BaseUint8ElementwiseWorkload<ParentDescriptor, Functor>::ExecuteImpl(const
std::vector<float> results(outputInfo.GetNumElements());
ElementwiseFunction<Functor>(inputInfo0.GetShape(),
- inputInfo1.GetShape(),
- outputInfo.GetShape(),
- dequant0.data(),
- dequant1.data(),
- results.data());
+ inputInfo1.GetShape(),
+ outputInfo.GetShape(),
+ dequant0.data(),
+ dequant1.data(),
+ results.data());
Quantize(GetOutputTensorDataU8(0, data), results.data(), outputInfo);
}
@@ -76,3 +76,6 @@ template class armnn::BaseUint8ElementwiseWorkload<armnn::MinimumQueueDescriptor
template class armnn::BaseFloat32ElementwiseWorkload<armnn::EqualQueueDescriptor, std::equal_to<float>>;
template class armnn::BaseUint8ElementwiseWorkload<armnn::EqualQueueDescriptor, std::equal_to<float>>;
+
+template class armnn::BaseFloat32ElementwiseWorkload<armnn::GreaterQueueDescriptor, std::greater<float>>;
+template class armnn::BaseUint8ElementwiseWorkload<armnn::GreaterQueueDescriptor, std::greater<float>>;
diff --git a/src/backends/reference/workloads/RefElementwiseWorkload.hpp b/src/backends/reference/workloads/RefElementwiseWorkload.hpp
index c2855b0550..1b3200f85c 100644
--- a/src/backends/reference/workloads/RefElementwiseWorkload.hpp
+++ b/src/backends/reference/workloads/RefElementwiseWorkload.hpp
@@ -12,8 +12,6 @@
#include "Maximum.hpp"
#include "Minimum.hpp"
-
-
namespace armnn
{
@@ -86,7 +84,6 @@ using RefAdditionUint8Workload =
AdditionQueueDescriptor,
StringMapping::RefAdditionWorkload_Execute>;
-
using RefSubtractionFloat32Workload =
RefElementwiseWorkload<std::minus<float>,
DataType::Float32,
@@ -132,9 +129,9 @@ using RefMaximumFloat32Workload =
using RefMaximumUint8Workload =
RefElementwiseWorkload<armnn::maximum<float>,
- DataType::QuantisedAsymm8,
- MaximumQueueDescriptor,
- StringMapping::RefMaximumWorkload_Execute>;
+ DataType::QuantisedAsymm8,
+ MaximumQueueDescriptor,
+ StringMapping::RefMaximumWorkload_Execute>;
using RefMinimumFloat32Workload =
RefElementwiseWorkload<minimum<float>,
@@ -159,4 +156,16 @@ using RefEqualUint8Workload =
DataType::QuantisedAsymm8,
EqualQueueDescriptor,
StringMapping::RefEqualWorkload_Execute>;
+
+using RefGreaterFloat32Workload =
+ RefElementwiseWorkload<std::greater<float>,
+ DataType::Float32,
+ GreaterQueueDescriptor,
+ StringMapping::RefGreaterWorkload_Execute>;
+
+using RefGreaterUint8Workload =
+ RefElementwiseWorkload<std::greater<float>,
+ DataType::QuantisedAsymm8,
+ GreaterQueueDescriptor,
+ StringMapping::RefGreaterWorkload_Execute>;
} // armnn