From 77bfb5e32faadb1383d48364a6f54adbff84ad80 Mon Sep 17 00:00:00 2001
From: Aron Virginas-Tar
Date: Wed, 16 Oct 2019 17:45:38 +0100
Subject: IVGCVSW-3993 Add frontend and reference workload for ComparisonLayer

* Added frontend for ComparisonLayer
* Added RefComparisonWorkload
* Deprecated and removed Equal and Greater layers and workloads
* Updated tests to ensure backward compatibility

Signed-off-by: Aron Virginas-Tar
Change-Id: Id50c880be1b567c531efff919c0c366d0a71cbe9
---
 src/backends/cl/ClLayerSupport.cpp       | 25 ++++++++++++++++++++-----
 src/backends/cl/ClLayerSupport.hpp       |  7 +++++++
 src/backends/cl/ClWorkloadFactory.cpp    | 28 ++++++++++++++++++++++++++--
 src/backends/cl/ClWorkloadFactory.hpp    |  3 +++
 src/backends/cl/test/ClEndToEndTests.cpp | 26 +++++++++++++-------------
 5 files changed, 69 insertions(+), 20 deletions(-)

(limited to 'src/backends/cl')

diff --git a/src/backends/cl/ClLayerSupport.cpp b/src/backends/cl/ClLayerSupport.cpp
index c5ed8bff2a..bd2be57386 100644
--- a/src/backends/cl/ClLayerSupport.cpp
+++ b/src/backends/cl/ClLayerSupport.cpp
@@ -209,6 +209,24 @@ bool ClLayerSupport::IsBatchToSpaceNdSupported(const TensorInfo& input,
                                   descriptor);
 }
 
+bool ClLayerSupport::IsComparisonSupported(const TensorInfo& input0,
+                                           const TensorInfo& input1,
+                                           const TensorInfo& output,
+                                           const ComparisonDescriptor& descriptor,
+                                           Optional<std::string&> reasonIfUnsupported) const
+{
+    if (descriptor.m_Operation == ComparisonOperation::Greater)
+    {
+        FORWARD_WORKLOAD_VALIDATE_FUNC(ClGreaterWorkloadValidate,
+                                       reasonIfUnsupported,
+                                       input0,
+                                       input1,
+                                       output);
+    }
+
+    return false;
+}
+
 bool ClLayerSupport::IsConcatSupported(const std::vector<const TensorInfo*> inputs,
                                        const TensorInfo& output,
                                        const ConcatDescriptor& descriptor,
@@ -398,11 +416,8 @@ bool ClLayerSupport::IsGreaterSupported(const TensorInfo& input0,
                                         const TensorInfo& output,
                                         Optional<std::string&> reasonIfUnsupported) const
 {
-    FORWARD_WORKLOAD_VALIDATE_FUNC(ClGreaterWorkloadValidate,
-                                   reasonIfUnsupported,
-                                   input0,
-                                   input1,
-                                   output);
+    ComparisonDescriptor descriptor(ComparisonOperation::Greater);
+    return IsComparisonSupported(input0, input1, output, descriptor, reasonIfUnsupported);
 }
 
 bool ClLayerSupport::IsInputSupported(const TensorInfo& input,
diff --git a/src/backends/cl/ClLayerSupport.hpp b/src/backends/cl/ClLayerSupport.hpp
index 59e849316f..26eb42e092 100644
--- a/src/backends/cl/ClLayerSupport.hpp
+++ b/src/backends/cl/ClLayerSupport.hpp
@@ -40,6 +40,12 @@ public:
                                    const BatchToSpaceNdDescriptor& descriptor,
                                    Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
 
+    bool IsComparisonSupported(const TensorInfo& input0,
+                               const TensorInfo& input1,
+                               const TensorInfo& ouput,
+                               const ComparisonDescriptor& descriptor,
+                               Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
+
     bool IsConcatSupported(const std::vector<const TensorInfo*> inputs,
                            const TensorInfo& output,
                            const ConcatDescriptor& descriptor,
@@ -102,6 +108,7 @@ public:
                                   const FullyConnectedDescriptor& descriptor,
                                   Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
 
+    ARMNN_DEPRECATED_MSG("Use IsComparisonSupported instead")
     bool IsGreaterSupported(const TensorInfo& input0,
                             const TensorInfo& input1,
                             const TensorInfo& ouput,
diff --git a/src/backends/cl/ClWorkloadFactory.cpp b/src/backends/cl/ClWorkloadFactory.cpp
index c427ae7e12..04e09f4ff1 100644
--- a/src/backends/cl/ClWorkloadFactory.cpp
+++ b/src/backends/cl/ClWorkloadFactory.cpp
@@ -157,6 +157,20 @@ std::unique_ptr<IWorkload> ClWorkloadFactory::CreateBatchToSpaceNd(const BatchTo
     return MakeWorkload(descriptor, info);
 }
 
+std::unique_ptr<IWorkload> ClWorkloadFactory::CreateComparison(const ComparisonQueueDescriptor& descriptor,
+                                                               const WorkloadInfo& info) const
+{
+    if (descriptor.m_Parameters.m_Operation == ComparisonOperation::Greater)
+    {
+        GreaterQueueDescriptor greaterQueueDescriptor;
+        greaterQueueDescriptor.m_Inputs  = descriptor.m_Inputs;
+        greaterQueueDescriptor.m_Outputs = descriptor.m_Outputs;
+
+        return MakeWorkload(greaterQueueDescriptor, info);
+    }
+    return MakeWorkload(descriptor, info);
+}
+
 std::unique_ptr<IWorkload> ClWorkloadFactory::CreateConcat(const ConcatQueueDescriptor& descriptor,
                                                            const WorkloadInfo& info) const
 {
@@ -230,7 +244,12 @@ std::unique_ptr<IWorkload> ClWorkloadFactory::CreateDivision(const DivisionQueue
 std::unique_ptr<IWorkload> ClWorkloadFactory::CreateEqual(const EqualQueueDescriptor& descriptor,
                                                           const WorkloadInfo& info) const
 {
-    return MakeWorkload(descriptor, info);
+    boost::ignore_unused(descriptor);
+
+    ComparisonQueueDescriptor comparisonDescriptor;
+    comparisonDescriptor.m_Parameters = ComparisonDescriptor(ComparisonOperation::Equal);
+
+    return CreateComparison(comparisonDescriptor, info);
 }
 
 std::unique_ptr<IWorkload> ClWorkloadFactory::CreateFakeQuantization(
@@ -261,7 +280,12 @@ std::unique_ptr<IWorkload> ClWorkloadFactory::CreateGather(const GatherQueueDesc
 std::unique_ptr<IWorkload> ClWorkloadFactory::CreateGreater(const GreaterQueueDescriptor& descriptor,
                                                             const WorkloadInfo& info) const
 {
-    return MakeWorkload(descriptor, info);
+    boost::ignore_unused(descriptor);
+
+    ComparisonQueueDescriptor comparisonDescriptor;
+    comparisonDescriptor.m_Parameters = ComparisonDescriptor(ComparisonOperation::Greater);
+
+    return CreateComparison(comparisonDescriptor, info);
 }
 
 std::unique_ptr<IWorkload> ClWorkloadFactory::CreateInput(const InputQueueDescriptor& descriptor,
diff --git a/src/backends/cl/ClWorkloadFactory.hpp b/src/backends/cl/ClWorkloadFactory.hpp
index 9dbc615a4e..1cae6e1faf 100644
--- a/src/backends/cl/ClWorkloadFactory.hpp
+++ b/src/backends/cl/ClWorkloadFactory.hpp
@@ -53,6 +53,9 @@ public:
     std::unique_ptr<IWorkload> CreateBatchToSpaceNd(const BatchToSpaceNdQueueDescriptor& descriptor,
                                                     const WorkloadInfo& info) const override;
 
+    std::unique_ptr<IWorkload> CreateComparison(const ComparisonQueueDescriptor& descriptor,
+                                                const WorkloadInfo& info) const override;
+
     std::unique_ptr<IWorkload> CreateConcat(const ConcatQueueDescriptor& descriptor,
                                             const WorkloadInfo& info) const override;
diff --git a/src/backends/cl/test/ClEndToEndTests.cpp b/src/backends/cl/test/ClEndToEndTests.cpp
index 59d26edf22..26f15b77da 100644
--- a/src/backends/cl/test/ClEndToEndTests.cpp
+++ b/src/backends/cl/test/ClEndToEndTests.cpp
@@ -6,7 +6,7 @@
 #include
 
 #include
-#include
+#include
 #include
 #include
 #include
@@ -122,9 +122,9 @@ BOOST_AUTO_TEST_CASE(ClGreaterSimpleEndToEndTest)
     const std::vector expectedOutput({ 0, 0, 0, 0,  1, 1, 1, 1,
                                        0, 0, 0, 0,  0, 0, 0, 0 });
 
-    ArithmeticSimpleEndToEnd(defaultBackends,
-                             LayerType::Greater,
-                             expectedOutput);
+    ComparisonSimpleEndToEnd(defaultBackends,
+                             ComparisonOperation::Greater,
+                             expectedOutput);
 }
 
 BOOST_AUTO_TEST_CASE(ClGreaterSimpleEndToEndUint8Test)
 {
     const std::vector expectedOutput({ 0, 0, 0, 0,  1, 1, 1, 1,
                                        0, 0, 0, 0,  0, 0, 0, 0 });
 
-    ArithmeticSimpleEndToEnd(defaultBackends,
-                             LayerType::Greater,
-                             expectedOutput);
+    ComparisonSimpleEndToEnd(defaultBackends,
+                             ComparisonOperation::Greater,
+                             expectedOutput);
 }
 
 BOOST_AUTO_TEST_CASE(ClGreaterBroadcastEndToEndTest)
 {
     const std::vector expectedOutput({ 0, 1, 0, 0,  0, 1, 1, 1,
                                        1, 1, 1, 1 });
 
-    ArithmeticBroadcastEndToEnd(defaultBackends,
-                                LayerType::Greater,
-                                expectedOutput);
+    ComparisonBroadcastEndToEnd(defaultBackends,
+                                ComparisonOperation::Greater,
+                                expectedOutput);
 }
 
 BOOST_AUTO_TEST_CASE(ClGreaterBroadcastEndToEndUint8Test)
 {
     const std::vector expectedOutput({ 0, 1, 0, 0,  0, 1, 1, 1,
                                        1, 1, 1, 1 });
 
-    ArithmeticBroadcastEndToEnd(defaultBackends,
-                                LayerType::Greater,
-                                expectedOutput);
+    ComparisonBroadcastEndToEnd(defaultBackends,
+                                ComparisonOperation::Greater,
+                                expectedOutput);
 }
 
 // InstanceNormalization
--
cgit v1.2.1
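
Editor's note on the frontend half of the commit message: the ComparisonLayer itself and RefComparisonWorkload live outside this excerpt, which is limited to src/backends/cl. The sketch below shows roughly how a caller would now build the new layer through INetwork::AddComparisonLayer instead of the deprecated AddGreaterLayer. It is not part of the patch; the tensor shapes, layer names, and the Boolean output type are illustrative assumptions based on the ArmNN API of this period rather than code taken from the diff above.

// Hedged sketch, not from the patch: building a Greater comparison with the
// new descriptor-driven frontend rather than the deprecated AddGreaterLayer.
#include <armnn/ArmNN.hpp>

armnn::INetworkPtr BuildGreaterNetwork()
{
    using namespace armnn;

    INetworkPtr network = INetwork::Create();

    IConnectableLayer* input0 = network->AddInputLayer(0, "input0");
    IConnectableLayer* input1 = network->AddInputLayer(1, "input1");

    // One Comparison layer parameterised by the operation; previously this
    // would have been network->AddGreaterLayer("greater").
    ComparisonDescriptor descriptor(ComparisonOperation::Greater);
    IConnectableLayer* comparison = network->AddComparisonLayer(descriptor, "greater");

    IConnectableLayer* output = network->AddOutputLayer(0, "output");

    input0->GetOutputSlot(0).Connect(comparison->GetInputSlot(0));
    input1->GetOutputSlot(0).Connect(comparison->GetInputSlot(1));
    comparison->GetOutputSlot(0).Connect(output->GetInputSlot(0));

    // Shapes and data types are placeholders; the comparison result is boolean.
    TensorInfo inputInfo({ 2, 2, 2, 2 }, DataType::Float32);
    TensorInfo outputInfo({ 2, 2, 2, 2 }, DataType::Boolean);
    input0->GetOutputSlot(0).SetTensorInfo(inputInfo);
    input1->GetOutputSlot(0).SetTensorInfo(inputInfo);
    comparison->GetOutputSlot(0).SetTensorInfo(outputInfo);

    return network;
}

Folding Equal and Greater into a single descriptor-driven layer is what makes the backward-compatibility shims in the diff possible: the deprecated IsGreaterSupported and CreateGreater entry points simply wrap a ComparisonDescriptor(ComparisonOperation::Greater) and forward to the new code path, and further comparison operations only need a new enum value rather than a new layer, workload, and factory method per backend.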