From 77bfb5e32faadb1383d48364a6f54adbff84ad80 Mon Sep 17 00:00:00 2001 From: Aron Virginas-Tar Date: Wed, 16 Oct 2019 17:45:38 +0100 Subject: IVGCVSW-3993 Add frontend and reference workload for ComparisonLayer * Added frontend for ComparisonLayer * Added RefComparisonWorkload * Deprecated and removed Equal and Greater layers and workloads * Updated tests to ensure backward compatibility Signed-off-by: Aron Virginas-Tar Change-Id: Id50c880be1b567c531efff919c0c366d0a71cbe9 --- src/backends/reference/RefLayerSupport.cpp | 85 ++++++++--------- src/backends/reference/RefLayerSupport.hpp | 8 ++ src/backends/reference/RefWorkloadFactory.cpp | 16 +++- src/backends/reference/RefWorkloadFactory.hpp | 5 + src/backends/reference/backend.mk | 1 + src/backends/reference/test/RefEndToEndTests.cpp | 50 +++++----- src/backends/reference/workloads/CMakeLists.txt | 2 + .../reference/workloads/ElementwiseFunction.cpp | 7 +- .../reference/workloads/RefComparisonWorkload.cpp | 102 +++++++++++++++++++++ .../reference/workloads/RefComparisonWorkload.hpp | 34 +++++++ .../reference/workloads/RefElementwiseWorkload.cpp | 8 -- .../reference/workloads/RefElementwiseWorkload.hpp | 9 -- src/backends/reference/workloads/RefWorkloads.hpp | 1 + src/backends/reference/workloads/StringMapping.hpp | 4 - 14 files changed, 237 insertions(+), 95 deletions(-) create mode 100644 src/backends/reference/workloads/RefComparisonWorkload.cpp create mode 100644 src/backends/reference/workloads/RefComparisonWorkload.hpp (limited to 'src/backends/reference') diff --git a/src/backends/reference/RefLayerSupport.cpp b/src/backends/reference/RefLayerSupport.cpp index 9342b29f47..c65886ba4d 100644 --- a/src/backends/reference/RefLayerSupport.cpp +++ b/src/backends/reference/RefLayerSupport.cpp @@ -308,6 +308,35 @@ bool RefLayerSupport::IsBatchToSpaceNdSupported(const TensorInfo& input, return supported; } +bool RefLayerSupport::IsComparisonSupported(const TensorInfo& input0, + const TensorInfo& input1, + const TensorInfo& output, + const ComparisonDescriptor& descriptor, + Optional reasonIfUnsupported) const +{ + boost::ignore_unused(descriptor); + + std::array supportedInputTypes = + { + DataType::Float32, + DataType::Float16, + DataType::QuantisedAsymm8, + DataType::QuantisedSymm16 + }; + + bool supported = true; + supported &= CheckSupportRule(TypeAnyOf(input0, supportedInputTypes), reasonIfUnsupported, + "Reference comparison: input 0 is not a supported type"); + + supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported, + "Reference comparison: input 0 and Input 1 types are mismatched"); + + supported &= CheckSupportRule(TypeIs(output, DataType::Boolean), reasonIfUnsupported, + "Reference comparison: output is not of type Boolean"); + + return supported; +} + bool RefLayerSupport::IsConcatSupported(const std::vector inputs, const TensorInfo& output, const ConcatDescriptor& descriptor, @@ -644,29 +673,11 @@ bool RefLayerSupport::IsEqualSupported(const TensorInfo& input0, const TensorInfo& output, Optional reasonIfUnsupported) const { - bool supported = true; - - std::array supportedTypes = - { - DataType::Float32, - DataType::Float16, - DataType::QuantisedAsymm8, - DataType::QuantisedSymm16 - }; - - supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported, - "Reference equal: input 0 is not a supported type."); - - supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported, - "Reference equal: input 1 is not a supported type."); - - supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported, - "Reference equal: input 0 and Input 1 types are mismatched"); - - supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported, - "Reference equal: shapes are not suitable for implicit broadcast."); - - return supported; + return IsComparisonSupported(input0, + input1, + output, + ComparisonDescriptor(ComparisonOperation::Equal), + reasonIfUnsupported); } bool RefLayerSupport::IsFakeQuantizationSupported(const TensorInfo& input, @@ -802,29 +813,11 @@ bool RefLayerSupport::IsGreaterSupported(const TensorInfo& input0, const TensorInfo& output, Optional reasonIfUnsupported) const { - bool supported = true; - - std::array supportedTypes = - { - DataType::Float32, - DataType::Float16, - DataType::QuantisedAsymm8, - DataType::QuantisedSymm16 - }; - - supported &= CheckSupportRule(TypeAnyOf(input0, supportedTypes), reasonIfUnsupported, - "Reference greater: input 0 is not a supported type."); - - supported &= CheckSupportRule(TypeAnyOf(input1, supportedTypes), reasonIfUnsupported, - "Reference greater: input 1 is not a supported type."); - - supported &= CheckSupportRule(TypesAreEqual(input0, input1), reasonIfUnsupported, - "Reference greater: input 0 and Input 1 types are mismatched"); - - supported &= CheckSupportRule(ShapesAreBroadcastCompatible(input0, input1, output), reasonIfUnsupported, - "Reference greater: shapes are not suitable for implicit broadcast."); - - return supported; + return IsComparisonSupported(input0, + input1, + output, + ComparisonDescriptor(ComparisonOperation::Greater), + reasonIfUnsupported); } bool RefLayerSupport::IsInputSupported(const TensorInfo& input, diff --git a/src/backends/reference/RefLayerSupport.hpp b/src/backends/reference/RefLayerSupport.hpp index 5c71e8d337..04b355ee0a 100644 --- a/src/backends/reference/RefLayerSupport.hpp +++ b/src/backends/reference/RefLayerSupport.hpp @@ -45,6 +45,12 @@ public: const BatchToSpaceNdDescriptor& descriptor, Optional reasonIfUnsupported = EmptyOptional()) const override; + bool IsComparisonSupported(const TensorInfo& input0, + const TensorInfo& input1, + const TensorInfo& output, + const ComparisonDescriptor& descriptor, + Optional reasonIfUnsupported = EmptyOptional()) const override; + bool IsConcatSupported(const std::vector inputs, const TensorInfo& output, const ConcatDescriptor& descriptor, @@ -106,6 +112,7 @@ public: const TensorInfo& output, Optional reasonIfUnsupported = EmptyOptional()) const override; + ARMNN_DEPRECATED_MSG("Use IsComparisonSupported instead") bool IsEqualSupported(const TensorInfo& input0, const TensorInfo& input1, const TensorInfo& output, @@ -131,6 +138,7 @@ public: const TensorInfo& output, Optional reasonIfUnsupported = EmptyOptional()) const override; + ARMNN_DEPRECATED_MSG("Use IsComparisonSupported instead") bool IsGreaterSupported(const TensorInfo& input0, const TensorInfo& input1, const TensorInfo& output, diff --git a/src/backends/reference/RefWorkloadFactory.cpp b/src/backends/reference/RefWorkloadFactory.cpp index 1f6d1d7e8b..c2cb51abf3 100644 --- a/src/backends/reference/RefWorkloadFactory.cpp +++ b/src/backends/reference/RefWorkloadFactory.cpp @@ -131,6 +131,12 @@ std::unique_ptr RefWorkloadFactory::CreateBatchToSpaceNd(const BatchT return std::make_unique(descriptor, info); } +std::unique_ptr RefWorkloadFactory::CreateComparison(const ComparisonQueueDescriptor& descriptor, + const WorkloadInfo& info) const +{ + return std::make_unique(descriptor, info); +} + std::unique_ptr RefWorkloadFactory::CreateConcat(const ConcatQueueDescriptor& descriptor, const WorkloadInfo& info) const { @@ -208,7 +214,10 @@ std::unique_ptr RefWorkloadFactory::CreateDivision(const DivisionQueu std::unique_ptr RefWorkloadFactory::CreateEqual(const EqualQueueDescriptor& descriptor, const WorkloadInfo& info) const { - return std::make_unique(descriptor, info); + ComparisonQueueDescriptor comparisonDescriptor; + comparisonDescriptor.m_Parameters.m_Operation = ComparisonOperation::Equal; + + return CreateComparison(comparisonDescriptor, info); } std::unique_ptr RefWorkloadFactory::CreateFakeQuantization( @@ -240,7 +249,10 @@ std::unique_ptr RefWorkloadFactory::CreateGather(const GatherQueueDes std::unique_ptr RefWorkloadFactory::CreateGreater(const GreaterQueueDescriptor& descriptor, const WorkloadInfo& info) const { - return std::make_unique(descriptor, info); + ComparisonQueueDescriptor comparisonDescriptor; + comparisonDescriptor.m_Parameters.m_Operation = ComparisonOperation::Greater; + + return CreateComparison(comparisonDescriptor, info); } std::unique_ptr RefWorkloadFactory::CreateInput(const InputQueueDescriptor& descriptor, diff --git a/src/backends/reference/RefWorkloadFactory.hpp b/src/backends/reference/RefWorkloadFactory.hpp index 41e9b28ea2..7b73d5b21f 100644 --- a/src/backends/reference/RefWorkloadFactory.hpp +++ b/src/backends/reference/RefWorkloadFactory.hpp @@ -78,6 +78,9 @@ public: std::unique_ptr CreateBatchToSpaceNd(const BatchToSpaceNdQueueDescriptor& descriptor, const WorkloadInfo& info) const override; + std::unique_ptr CreateComparison(const ComparisonQueueDescriptor& descriptor, + const WorkloadInfo& info) const override; + std::unique_ptr CreateConcat(const ConcatQueueDescriptor& descriptor, const WorkloadInfo& info) const override; @@ -111,6 +114,7 @@ public: std::unique_ptr CreateDivision(const DivisionQueueDescriptor& descriptor, const WorkloadInfo& info) const override; + ARMNN_DEPRECATED_MSG("Use CreateComparison instead") std::unique_ptr CreateEqual(const EqualQueueDescriptor& descriptor, const WorkloadInfo& info) const override; @@ -126,6 +130,7 @@ public: std::unique_ptr CreateGather(const GatherQueueDescriptor& descriptor, const WorkloadInfo& info) const override; + ARMNN_DEPRECATED_MSG("Use CreateComparison instead") std::unique_ptr CreateGreater(const GreaterQueueDescriptor& descriptor, const WorkloadInfo& info) const override; diff --git a/src/backends/reference/backend.mk b/src/backends/reference/backend.mk index 49b07a41d2..7e97acdee2 100644 --- a/src/backends/reference/backend.mk +++ b/src/backends/reference/backend.mk @@ -47,6 +47,7 @@ BACKEND_SOURCES := \ workloads/RefArgMinMaxWorkload.cpp \ workloads/RefBatchNormalizationWorkload.cpp \ workloads/RefBatchToSpaceNdWorkload.cpp \ + workloads/RefComparisonWorkload.cpp \ workloads/RefConcatWorkload.cpp \ workloads/RefConstantWorkload.cpp \ workloads/RefConvertFp16ToFp32Workload.cpp \ diff --git a/src/backends/reference/test/RefEndToEndTests.cpp b/src/backends/reference/test/RefEndToEndTests.cpp index 370ef6599b..1968e4da7e 100644 --- a/src/backends/reference/test/RefEndToEndTests.cpp +++ b/src/backends/reference/test/RefEndToEndTests.cpp @@ -6,8 +6,8 @@ #include #include -#include #include +#include #include #include #include @@ -348,9 +348,9 @@ BOOST_AUTO_TEST_CASE(RefEqualSimpleEndToEndTest) const std::vector expectedOutput({ 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1 }); - ArithmeticSimpleEndToEnd(defaultBackends, - LayerType::Equal, - expectedOutput); + ComparisonSimpleEndToEnd(defaultBackends, + ComparisonOperation::Equal, + expectedOutput); } BOOST_AUTO_TEST_CASE(RefGreaterSimpleEndToEndTest) @@ -358,9 +358,9 @@ BOOST_AUTO_TEST_CASE(RefGreaterSimpleEndToEndTest) const std::vector expectedOutput({ 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0 }); - ArithmeticSimpleEndToEnd(defaultBackends, - LayerType::Greater, - expectedOutput); + ComparisonSimpleEndToEnd(defaultBackends, + ComparisonOperation::Greater, + expectedOutput); } BOOST_AUTO_TEST_CASE(RefEqualSimpleEndToEndUint8Test) @@ -368,9 +368,9 @@ BOOST_AUTO_TEST_CASE(RefEqualSimpleEndToEndUint8Test) const std::vector expectedOutput({ 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1 }); - ArithmeticSimpleEndToEnd(defaultBackends, - LayerType::Equal, - expectedOutput); + ComparisonSimpleEndToEnd(defaultBackends, + ComparisonOperation::Equal, + expectedOutput); } BOOST_AUTO_TEST_CASE(RefGreaterSimpleEndToEndUint8Test) @@ -378,9 +378,9 @@ BOOST_AUTO_TEST_CASE(RefGreaterSimpleEndToEndUint8Test) const std::vector expectedOutput({ 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0 }); - ArithmeticSimpleEndToEnd(defaultBackends, - LayerType::Greater, - expectedOutput); + ComparisonSimpleEndToEnd(defaultBackends, + ComparisonOperation::Greater, + expectedOutput); } BOOST_AUTO_TEST_CASE(RefEqualBroadcastEndToEndTest) @@ -388,9 +388,9 @@ BOOST_AUTO_TEST_CASE(RefEqualBroadcastEndToEndTest) const std::vector expectedOutput({ 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0 }); - ArithmeticBroadcastEndToEnd(defaultBackends, - LayerType::Equal, - expectedOutput); + ComparisonBroadcastEndToEnd(defaultBackends, + ComparisonOperation::Equal, + expectedOutput); } BOOST_AUTO_TEST_CASE(RefGreaterBroadcastEndToEndTest) @@ -398,9 +398,9 @@ BOOST_AUTO_TEST_CASE(RefGreaterBroadcastEndToEndTest) const std::vector expectedOutput({ 0, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1 }); - ArithmeticBroadcastEndToEnd(defaultBackends, - LayerType::Greater, - expectedOutput); + ComparisonBroadcastEndToEnd(defaultBackends, + ComparisonOperation::Greater, + expectedOutput); } BOOST_AUTO_TEST_CASE(RefEqualBroadcastEndToEndUint8Test) @@ -408,9 +408,9 @@ BOOST_AUTO_TEST_CASE(RefEqualBroadcastEndToEndUint8Test) const std::vector expectedOutput({ 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0 }); - ArithmeticBroadcastEndToEnd(defaultBackends, - LayerType::Equal, - expectedOutput); + ComparisonBroadcastEndToEnd(defaultBackends, + ComparisonOperation::Equal, + expectedOutput); } BOOST_AUTO_TEST_CASE(RefGreaterBroadcastEndToEndUint8Test) @@ -418,9 +418,9 @@ BOOST_AUTO_TEST_CASE(RefGreaterBroadcastEndToEndUint8Test) const std::vector expectedOutput({ 0, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1 }); - ArithmeticBroadcastEndToEnd(defaultBackends, - LayerType::Greater, - expectedOutput); + ComparisonBroadcastEndToEnd(defaultBackends, + ComparisonOperation::Greater, + expectedOutput); } BOOST_AUTO_TEST_CASE(RefBatchToSpaceNdEndToEndFloat32NHWCTest) diff --git a/src/backends/reference/workloads/CMakeLists.txt b/src/backends/reference/workloads/CMakeLists.txt index b8eb95c729..7844518620 100644 --- a/src/backends/reference/workloads/CMakeLists.txt +++ b/src/backends/reference/workloads/CMakeLists.txt @@ -63,6 +63,8 @@ list(APPEND armnnRefBackendWorkloads_sources RefBatchNormalizationWorkload.hpp RefBatchToSpaceNdWorkload.cpp RefBatchToSpaceNdWorkload.hpp + RefComparisonWorkload.cpp + RefComparisonWorkload.hpp RefConcatWorkload.cpp RefConcatWorkload.hpp RefConstantWorkload.cpp diff --git a/src/backends/reference/workloads/ElementwiseFunction.cpp b/src/backends/reference/workloads/ElementwiseFunction.cpp index 7a5c071f70..888037f9a6 100644 --- a/src/backends/reference/workloads/ElementwiseFunction.cpp +++ b/src/backends/reference/workloads/ElementwiseFunction.cpp @@ -32,6 +32,11 @@ template struct armnn::ElementwiseFunction>; template struct armnn::ElementwiseFunction>; template struct armnn::ElementwiseFunction>; template struct armnn::ElementwiseFunction>; + +// Comparison template struct armnn::ElementwiseFunction>; template struct armnn::ElementwiseFunction>; - +template struct armnn::ElementwiseFunction>; +template struct armnn::ElementwiseFunction>; +template struct armnn::ElementwiseFunction>; +template struct armnn::ElementwiseFunction>; diff --git a/src/backends/reference/workloads/RefComparisonWorkload.cpp b/src/backends/reference/workloads/RefComparisonWorkload.cpp new file mode 100644 index 0000000000..60446226be --- /dev/null +++ b/src/backends/reference/workloads/RefComparisonWorkload.cpp @@ -0,0 +1,102 @@ +// +// Copyright © 2019 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#include "RefComparisonWorkload.hpp" + +#include "Decoders.hpp" +#include "ElementwiseFunction.hpp" +#include "Encoders.hpp" +#include "RefWorkloadUtils.hpp" + +#include + +#include + +#include + +namespace armnn +{ + +RefComparisonWorkload::RefComparisonWorkload(const ComparisonQueueDescriptor& desc, + const WorkloadInfo& info) + : BaseWorkload(desc, info) +{} + +void RefComparisonWorkload::PostAllocationConfigure() +{ + const TensorInfo& inputInfo0 = GetTensorInfo(m_Data.m_Inputs[0]); + const TensorInfo& inputInfo1 = GetTensorInfo(m_Data.m_Inputs[1]); + const TensorInfo& outputInfo = GetTensorInfo(m_Data.m_Outputs[0]); + + m_Input0 = MakeDecoder(inputInfo0); + m_Input1 = MakeDecoder(inputInfo1); + + m_Output = MakeEncoder(outputInfo); +} + +void RefComparisonWorkload::Execute() const +{ + ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefComparisonWorkload_Execute"); + + const TensorInfo& inputInfo0 = GetTensorInfo(m_Data.m_Inputs[0]); + const TensorInfo& inputInfo1 = GetTensorInfo(m_Data.m_Inputs[1]); + const TensorInfo& outputInfo = GetTensorInfo(m_Data.m_Outputs[0]); + + const TensorShape& inShape0 = inputInfo0.GetShape(); + const TensorShape& inShape1 = inputInfo1.GetShape(); + const TensorShape& outShape = outputInfo.GetShape(); + + m_Input0->Reset(m_Data.m_Inputs[0]->Map()); + m_Input1->Reset(m_Data.m_Inputs[1]->Map()); + m_Output->Reset(m_Data.m_Outputs[0]->Map()); + + using EqualFunction = ElementwiseFunction>; + using GreaterFunction = ElementwiseFunction>; + using GreaterOrEqualFunction = ElementwiseFunction>; + using LessFunction = ElementwiseFunction>; + using LessOrEqualFunction = ElementwiseFunction>; + using NotEqualFunction = ElementwiseFunction>; + + switch (m_Data.m_Parameters.m_Operation) + { + case ComparisonOperation::Equal: + { + EqualFunction(inShape0, inShape1, outShape, *m_Input0, *m_Input1, *m_Output); + break; + } + case ComparisonOperation::Greater: + { + GreaterFunction(inShape0, inShape1, outShape, *m_Input0, *m_Input1, *m_Output); + break; + } + case ComparisonOperation::GreaterOrEqual: + { + GreaterOrEqualFunction(inShape0, inShape1, outShape, *m_Input0, *m_Input1, *m_Output); + break; + } + case ComparisonOperation::Less: + { + LessFunction(inShape0, inShape1, outShape, *m_Input0, *m_Input1, *m_Output); + break; + } + case ComparisonOperation::LessOrEqual: + { + LessOrEqualFunction(inShape0, inShape1, outShape, *m_Input0, *m_Input1, *m_Output); + break; + } + case ComparisonOperation::NotEqual: + { + NotEqualFunction(inShape0, inShape1, outShape, *m_Input0, *m_Input1, *m_Output); + break; + } + default: + { + throw InvalidArgumentException(std::string("Unsupported comparison operation ") + + GetComparisonOperationAsCString(m_Data.m_Parameters.m_Operation), CHECK_LOCATION()); + } + } +} + +} // namespace armnn diff --git a/src/backends/reference/workloads/RefComparisonWorkload.hpp b/src/backends/reference/workloads/RefComparisonWorkload.hpp new file mode 100644 index 0000000000..a19e4a0540 --- /dev/null +++ b/src/backends/reference/workloads/RefComparisonWorkload.hpp @@ -0,0 +1,34 @@ +// +// Copyright © 2019 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#pragma once + +#include "BaseIterator.hpp" + +#include +#include + +namespace armnn +{ + +class RefComparisonWorkload : public BaseWorkload +{ +public: + using BaseWorkload::m_Data; + + RefComparisonWorkload(const ComparisonQueueDescriptor& descriptor, const WorkloadInfo& info); + void PostAllocationConfigure() override; + void Execute() const override; + +private: + using InType = float; + using OutType = bool; + + std::unique_ptr> m_Input0; + std::unique_ptr> m_Input1; + std::unique_ptr> m_Output; +}; + +} // namespace armnn diff --git a/src/backends/reference/workloads/RefElementwiseWorkload.cpp b/src/backends/reference/workloads/RefElementwiseWorkload.cpp index 6431348bc2..7e02f032ef 100644 --- a/src/backends/reference/workloads/RefElementwiseWorkload.cpp +++ b/src/backends/reference/workloads/RefElementwiseWorkload.cpp @@ -86,11 +86,3 @@ template class armnn::RefElementwiseWorkload, template class armnn::RefElementwiseWorkload, armnn::MinimumQueueDescriptor, armnn::StringMapping::RefMinimumWorkload_Execute>; - -template class armnn::RefElementwiseWorkload, - armnn::EqualQueueDescriptor, - armnn::StringMapping::RefEqualWorkload_Execute>; - -template class armnn::RefElementwiseWorkload, - armnn::GreaterQueueDescriptor, - armnn::StringMapping::RefGreaterWorkload_Execute>; diff --git a/src/backends/reference/workloads/RefElementwiseWorkload.hpp b/src/backends/reference/workloads/RefElementwiseWorkload.hpp index 651942e9e5..ee0d80b172 100644 --- a/src/backends/reference/workloads/RefElementwiseWorkload.hpp +++ b/src/backends/reference/workloads/RefElementwiseWorkload.hpp @@ -65,13 +65,4 @@ using RefMinimumWorkload = MinimumQueueDescriptor, StringMapping::RefMinimumWorkload_Execute>; -using RefEqualWorkload = - RefElementwiseWorkload, - armnn::EqualQueueDescriptor, - armnn::StringMapping::RefEqualWorkload_Execute>; - -using RefGreaterWorkload = - RefElementwiseWorkload, - armnn::GreaterQueueDescriptor, - armnn::StringMapping::RefGreaterWorkload_Execute>; } // armnn diff --git a/src/backends/reference/workloads/RefWorkloads.hpp b/src/backends/reference/workloads/RefWorkloads.hpp index 79d1935823..1f9ad4a19a 100644 --- a/src/backends/reference/workloads/RefWorkloads.hpp +++ b/src/backends/reference/workloads/RefWorkloads.hpp @@ -20,6 +20,7 @@ #include "RefArgMinMaxWorkload.hpp" #include "RefBatchNormalizationWorkload.hpp" #include "RefBatchToSpaceNdWorkload.hpp" +#include "RefComparisonWorkload.hpp" #include "RefConvolution2dWorkload.hpp" #include "RefConstantWorkload.hpp" #include "RefConcatWorkload.hpp" diff --git a/src/backends/reference/workloads/StringMapping.hpp b/src/backends/reference/workloads/StringMapping.hpp index 073a5a6833..1654b78088 100644 --- a/src/backends/reference/workloads/StringMapping.hpp +++ b/src/backends/reference/workloads/StringMapping.hpp @@ -18,9 +18,7 @@ struct StringMapping public: enum Id { RefAdditionWorkload_Execute, - RefEqualWorkload_Execute, RefDivisionWorkload_Execute, - RefGreaterWorkload_Execute, RefMaximumWorkload_Execute, RefMinimumWorkload_Execute, RefMultiplicationWorkload_Execute, @@ -40,8 +38,6 @@ private: { m_Strings[RefAdditionWorkload_Execute] = "RefAdditionWorkload_Execute"; m_Strings[RefDivisionWorkload_Execute] = "RefDivisionWorkload_Execute"; - m_Strings[RefEqualWorkload_Execute] = "RefEqualWorkload_Execute"; - m_Strings[RefGreaterWorkload_Execute] = "RefGreaterWorkload_Execute"; m_Strings[RefMaximumWorkload_Execute] = "RefMaximumWorkload_Execute"; m_Strings[RefMinimumWorkload_Execute] = "RefMinimumWorkload_Execute"; m_Strings[RefMultiplicationWorkload_Execute] = "RefMultiplicationWorkload_Execute"; -- cgit v1.2.1