From 77bfb5e32faadb1383d48364a6f54adbff84ad80 Mon Sep 17 00:00:00 2001 From: Aron Virginas-Tar Date: Wed, 16 Oct 2019 17:45:38 +0100 Subject: IVGCVSW-3993 Add frontend and reference workload for ComparisonLayer * Added frontend for ComparisonLayer * Added RefComparisonWorkload * Deprecated and removed Equal and Greater layers and workloads * Updated tests to ensure backward compatibility Signed-off-by: Aron Virginas-Tar Change-Id: Id50c880be1b567c531efff919c0c366d0a71cbe9 --- src/backends/reference/workloads/CMakeLists.txt | 2 + .../reference/workloads/ElementwiseFunction.cpp | 7 +- .../reference/workloads/RefComparisonWorkload.cpp | 102 +++++++++++++++++++++ .../reference/workloads/RefComparisonWorkload.hpp | 34 +++++++ .../reference/workloads/RefElementwiseWorkload.cpp | 8 -- .../reference/workloads/RefElementwiseWorkload.hpp | 9 -- src/backends/reference/workloads/RefWorkloads.hpp | 1 + src/backends/reference/workloads/StringMapping.hpp | 4 - 8 files changed, 145 insertions(+), 22 deletions(-) create mode 100644 src/backends/reference/workloads/RefComparisonWorkload.cpp create mode 100644 src/backends/reference/workloads/RefComparisonWorkload.hpp (limited to 'src/backends/reference/workloads') diff --git a/src/backends/reference/workloads/CMakeLists.txt b/src/backends/reference/workloads/CMakeLists.txt index b8eb95c729..7844518620 100644 --- a/src/backends/reference/workloads/CMakeLists.txt +++ b/src/backends/reference/workloads/CMakeLists.txt @@ -63,6 +63,8 @@ list(APPEND armnnRefBackendWorkloads_sources RefBatchNormalizationWorkload.hpp RefBatchToSpaceNdWorkload.cpp RefBatchToSpaceNdWorkload.hpp + RefComparisonWorkload.cpp + RefComparisonWorkload.hpp RefConcatWorkload.cpp RefConcatWorkload.hpp RefConstantWorkload.cpp diff --git a/src/backends/reference/workloads/ElementwiseFunction.cpp b/src/backends/reference/workloads/ElementwiseFunction.cpp index 7a5c071f70..888037f9a6 100644 --- a/src/backends/reference/workloads/ElementwiseFunction.cpp +++ b/src/backends/reference/workloads/ElementwiseFunction.cpp @@ -32,6 +32,11 @@ template struct armnn::ElementwiseFunction>; template struct armnn::ElementwiseFunction>; template struct armnn::ElementwiseFunction>; template struct armnn::ElementwiseFunction>; + +// Comparison template struct armnn::ElementwiseFunction>; template struct armnn::ElementwiseFunction>; - +template struct armnn::ElementwiseFunction>; +template struct armnn::ElementwiseFunction>; +template struct armnn::ElementwiseFunction>; +template struct armnn::ElementwiseFunction>; diff --git a/src/backends/reference/workloads/RefComparisonWorkload.cpp b/src/backends/reference/workloads/RefComparisonWorkload.cpp new file mode 100644 index 0000000000..60446226be --- /dev/null +++ b/src/backends/reference/workloads/RefComparisonWorkload.cpp @@ -0,0 +1,102 @@ +// +// Copyright © 2019 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#include "RefComparisonWorkload.hpp" + +#include "Decoders.hpp" +#include "ElementwiseFunction.hpp" +#include "Encoders.hpp" +#include "RefWorkloadUtils.hpp" + +#include + +#include + +#include + +namespace armnn +{ + +RefComparisonWorkload::RefComparisonWorkload(const ComparisonQueueDescriptor& desc, + const WorkloadInfo& info) + : BaseWorkload(desc, info) +{} + +void RefComparisonWorkload::PostAllocationConfigure() +{ + const TensorInfo& inputInfo0 = GetTensorInfo(m_Data.m_Inputs[0]); + const TensorInfo& inputInfo1 = GetTensorInfo(m_Data.m_Inputs[1]); + const TensorInfo& outputInfo = GetTensorInfo(m_Data.m_Outputs[0]); + + m_Input0 = MakeDecoder(inputInfo0); + m_Input1 = MakeDecoder(inputInfo1); + + m_Output = MakeEncoder(outputInfo); +} + +void RefComparisonWorkload::Execute() const +{ + ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefComparisonWorkload_Execute"); + + const TensorInfo& inputInfo0 = GetTensorInfo(m_Data.m_Inputs[0]); + const TensorInfo& inputInfo1 = GetTensorInfo(m_Data.m_Inputs[1]); + const TensorInfo& outputInfo = GetTensorInfo(m_Data.m_Outputs[0]); + + const TensorShape& inShape0 = inputInfo0.GetShape(); + const TensorShape& inShape1 = inputInfo1.GetShape(); + const TensorShape& outShape = outputInfo.GetShape(); + + m_Input0->Reset(m_Data.m_Inputs[0]->Map()); + m_Input1->Reset(m_Data.m_Inputs[1]->Map()); + m_Output->Reset(m_Data.m_Outputs[0]->Map()); + + using EqualFunction = ElementwiseFunction>; + using GreaterFunction = ElementwiseFunction>; + using GreaterOrEqualFunction = ElementwiseFunction>; + using LessFunction = ElementwiseFunction>; + using LessOrEqualFunction = ElementwiseFunction>; + using NotEqualFunction = ElementwiseFunction>; + + switch (m_Data.m_Parameters.m_Operation) + { + case ComparisonOperation::Equal: + { + EqualFunction(inShape0, inShape1, outShape, *m_Input0, *m_Input1, *m_Output); + break; + } + case ComparisonOperation::Greater: + { + GreaterFunction(inShape0, inShape1, outShape, *m_Input0, *m_Input1, *m_Output); + break; + } + case ComparisonOperation::GreaterOrEqual: + { + GreaterOrEqualFunction(inShape0, inShape1, outShape, *m_Input0, *m_Input1, *m_Output); + break; + } + case ComparisonOperation::Less: + { + LessFunction(inShape0, inShape1, outShape, *m_Input0, *m_Input1, *m_Output); + break; + } + case ComparisonOperation::LessOrEqual: + { + LessOrEqualFunction(inShape0, inShape1, outShape, *m_Input0, *m_Input1, *m_Output); + break; + } + case ComparisonOperation::NotEqual: + { + NotEqualFunction(inShape0, inShape1, outShape, *m_Input0, *m_Input1, *m_Output); + break; + } + default: + { + throw InvalidArgumentException(std::string("Unsupported comparison operation ") + + GetComparisonOperationAsCString(m_Data.m_Parameters.m_Operation), CHECK_LOCATION()); + } + } +} + +} // namespace armnn diff --git a/src/backends/reference/workloads/RefComparisonWorkload.hpp b/src/backends/reference/workloads/RefComparisonWorkload.hpp new file mode 100644 index 0000000000..a19e4a0540 --- /dev/null +++ b/src/backends/reference/workloads/RefComparisonWorkload.hpp @@ -0,0 +1,34 @@ +// +// Copyright © 2019 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#pragma once + +#include "BaseIterator.hpp" + +#include +#include + +namespace armnn +{ + +class RefComparisonWorkload : public BaseWorkload +{ +public: + using BaseWorkload::m_Data; + + RefComparisonWorkload(const ComparisonQueueDescriptor& descriptor, const WorkloadInfo& info); + void PostAllocationConfigure() override; + void Execute() const override; + +private: + using InType = float; + using OutType = bool; + + std::unique_ptr> m_Input0; + std::unique_ptr> m_Input1; + std::unique_ptr> m_Output; +}; + +} // namespace armnn diff --git a/src/backends/reference/workloads/RefElementwiseWorkload.cpp b/src/backends/reference/workloads/RefElementwiseWorkload.cpp index 6431348bc2..7e02f032ef 100644 --- a/src/backends/reference/workloads/RefElementwiseWorkload.cpp +++ b/src/backends/reference/workloads/RefElementwiseWorkload.cpp @@ -86,11 +86,3 @@ template class armnn::RefElementwiseWorkload, template class armnn::RefElementwiseWorkload, armnn::MinimumQueueDescriptor, armnn::StringMapping::RefMinimumWorkload_Execute>; - -template class armnn::RefElementwiseWorkload, - armnn::EqualQueueDescriptor, - armnn::StringMapping::RefEqualWorkload_Execute>; - -template class armnn::RefElementwiseWorkload, - armnn::GreaterQueueDescriptor, - armnn::StringMapping::RefGreaterWorkload_Execute>; diff --git a/src/backends/reference/workloads/RefElementwiseWorkload.hpp b/src/backends/reference/workloads/RefElementwiseWorkload.hpp index 651942e9e5..ee0d80b172 100644 --- a/src/backends/reference/workloads/RefElementwiseWorkload.hpp +++ b/src/backends/reference/workloads/RefElementwiseWorkload.hpp @@ -65,13 +65,4 @@ using RefMinimumWorkload = MinimumQueueDescriptor, StringMapping::RefMinimumWorkload_Execute>; -using RefEqualWorkload = - RefElementwiseWorkload, - armnn::EqualQueueDescriptor, - armnn::StringMapping::RefEqualWorkload_Execute>; - -using RefGreaterWorkload = - RefElementwiseWorkload, - armnn::GreaterQueueDescriptor, - armnn::StringMapping::RefGreaterWorkload_Execute>; } // armnn diff --git a/src/backends/reference/workloads/RefWorkloads.hpp b/src/backends/reference/workloads/RefWorkloads.hpp index 79d1935823..1f9ad4a19a 100644 --- a/src/backends/reference/workloads/RefWorkloads.hpp +++ b/src/backends/reference/workloads/RefWorkloads.hpp @@ -20,6 +20,7 @@ #include "RefArgMinMaxWorkload.hpp" #include "RefBatchNormalizationWorkload.hpp" #include "RefBatchToSpaceNdWorkload.hpp" +#include "RefComparisonWorkload.hpp" #include "RefConvolution2dWorkload.hpp" #include "RefConstantWorkload.hpp" #include "RefConcatWorkload.hpp" diff --git a/src/backends/reference/workloads/StringMapping.hpp b/src/backends/reference/workloads/StringMapping.hpp index 073a5a6833..1654b78088 100644 --- a/src/backends/reference/workloads/StringMapping.hpp +++ b/src/backends/reference/workloads/StringMapping.hpp @@ -18,9 +18,7 @@ struct StringMapping public: enum Id { RefAdditionWorkload_Execute, - RefEqualWorkload_Execute, RefDivisionWorkload_Execute, - RefGreaterWorkload_Execute, RefMaximumWorkload_Execute, RefMinimumWorkload_Execute, RefMultiplicationWorkload_Execute, @@ -40,8 +38,6 @@ private: { m_Strings[RefAdditionWorkload_Execute] = "RefAdditionWorkload_Execute"; m_Strings[RefDivisionWorkload_Execute] = "RefDivisionWorkload_Execute"; - m_Strings[RefEqualWorkload_Execute] = "RefEqualWorkload_Execute"; - m_Strings[RefGreaterWorkload_Execute] = "RefGreaterWorkload_Execute"; m_Strings[RefMaximumWorkload_Execute] = "RefMaximumWorkload_Execute"; m_Strings[RefMinimumWorkload_Execute] = "RefMinimumWorkload_Execute"; m_Strings[RefMultiplicationWorkload_Execute] = "RefMultiplicationWorkload_Execute"; -- cgit v1.2.1