aboutsummaryrefslogtreecommitdiff
path: root/src/backends/reference/workloads
diff options
context:
space:
mode:
authorAron Virginas-Tar <Aron.Virginas-Tar@arm.com>2019-10-16 17:45:38 +0100
committerÁron Virginás-Tar <aron.virginas-tar@arm.com>2019-10-21 08:52:04 +0000
commit77bfb5e32faadb1383d48364a6f54adbff84ad80 (patch)
tree0bf5dfb48cb8d5c248baf716f02b9f481400316e /src/backends/reference/workloads
parent5884708e650a80e355398532bc320bbabdbb53f4 (diff)
downloadarmnn-77bfb5e32faadb1383d48364a6f54adbff84ad80.tar.gz
IVGCVSW-3993 Add frontend and reference workload for ComparisonLayer
* Added frontend for ComparisonLayer * Added RefComparisonWorkload * Deprecated and removed Equal and Greater layers and workloads * Updated tests to ensure backward compatibility Signed-off-by: Aron Virginas-Tar <Aron.Virginas-Tar@arm.com> Change-Id: Id50c880be1b567c531efff919c0c366d0a71cbe9
Diffstat (limited to 'src/backends/reference/workloads')
-rw-r--r--src/backends/reference/workloads/CMakeLists.txt2
-rw-r--r--src/backends/reference/workloads/ElementwiseFunction.cpp7
-rw-r--r--src/backends/reference/workloads/RefComparisonWorkload.cpp102
-rw-r--r--src/backends/reference/workloads/RefComparisonWorkload.hpp34
-rw-r--r--src/backends/reference/workloads/RefElementwiseWorkload.cpp8
-rw-r--r--src/backends/reference/workloads/RefElementwiseWorkload.hpp9
-rw-r--r--src/backends/reference/workloads/RefWorkloads.hpp1
-rw-r--r--src/backends/reference/workloads/StringMapping.hpp4
8 files changed, 145 insertions, 22 deletions
diff --git a/src/backends/reference/workloads/CMakeLists.txt b/src/backends/reference/workloads/CMakeLists.txt
index b8eb95c729..7844518620 100644
--- a/src/backends/reference/workloads/CMakeLists.txt
+++ b/src/backends/reference/workloads/CMakeLists.txt
@@ -63,6 +63,8 @@ list(APPEND armnnRefBackendWorkloads_sources
RefBatchNormalizationWorkload.hpp
RefBatchToSpaceNdWorkload.cpp
RefBatchToSpaceNdWorkload.hpp
+ RefComparisonWorkload.cpp
+ RefComparisonWorkload.hpp
RefConcatWorkload.cpp
RefConcatWorkload.hpp
RefConstantWorkload.cpp
diff --git a/src/backends/reference/workloads/ElementwiseFunction.cpp b/src/backends/reference/workloads/ElementwiseFunction.cpp
index 7a5c071f70..888037f9a6 100644
--- a/src/backends/reference/workloads/ElementwiseFunction.cpp
+++ b/src/backends/reference/workloads/ElementwiseFunction.cpp
@@ -32,6 +32,11 @@ template struct armnn::ElementwiseFunction<std::multiplies<float>>;
template struct armnn::ElementwiseFunction<std::divides<float>>;
template struct armnn::ElementwiseFunction<armnn::maximum<float>>;
template struct armnn::ElementwiseFunction<armnn::minimum<float>>;
+
+// Comparison
template struct armnn::ElementwiseFunction<std::equal_to<float>>;
template struct armnn::ElementwiseFunction<std::greater<float>>;
-
+template struct armnn::ElementwiseFunction<std::greater_equal<float>>;
+template struct armnn::ElementwiseFunction<std::less<float>>;
+template struct armnn::ElementwiseFunction<std::less_equal<float>>;
+template struct armnn::ElementwiseFunction<std::not_equal_to<float>>;
diff --git a/src/backends/reference/workloads/RefComparisonWorkload.cpp b/src/backends/reference/workloads/RefComparisonWorkload.cpp
new file mode 100644
index 0000000000..60446226be
--- /dev/null
+++ b/src/backends/reference/workloads/RefComparisonWorkload.cpp
@@ -0,0 +1,102 @@
+//
+// Copyright © 2019 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include "RefComparisonWorkload.hpp"
+
+#include "Decoders.hpp"
+#include "ElementwiseFunction.hpp"
+#include "Encoders.hpp"
+#include "RefWorkloadUtils.hpp"
+
+#include <Profiling.hpp>
+
+#include <armnn/TypesUtils.hpp>
+
+#include <functional>
+
+namespace armnn
+{
+
+RefComparisonWorkload::RefComparisonWorkload(const ComparisonQueueDescriptor& desc,
+ const WorkloadInfo& info)
+ : BaseWorkload<ComparisonQueueDescriptor>(desc, info)
+{}
+
+void RefComparisonWorkload::PostAllocationConfigure()
+{
+ const TensorInfo& inputInfo0 = GetTensorInfo(m_Data.m_Inputs[0]);
+ const TensorInfo& inputInfo1 = GetTensorInfo(m_Data.m_Inputs[1]);
+ const TensorInfo& outputInfo = GetTensorInfo(m_Data.m_Outputs[0]);
+
+ m_Input0 = MakeDecoder<InType>(inputInfo0);
+ m_Input1 = MakeDecoder<InType>(inputInfo1);
+
+ m_Output = MakeEncoder<OutType>(outputInfo);
+}
+
+void RefComparisonWorkload::Execute() const
+{
+ ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefComparisonWorkload_Execute");
+
+ const TensorInfo& inputInfo0 = GetTensorInfo(m_Data.m_Inputs[0]);
+ const TensorInfo& inputInfo1 = GetTensorInfo(m_Data.m_Inputs[1]);
+ const TensorInfo& outputInfo = GetTensorInfo(m_Data.m_Outputs[0]);
+
+ const TensorShape& inShape0 = inputInfo0.GetShape();
+ const TensorShape& inShape1 = inputInfo1.GetShape();
+ const TensorShape& outShape = outputInfo.GetShape();
+
+ m_Input0->Reset(m_Data.m_Inputs[0]->Map());
+ m_Input1->Reset(m_Data.m_Inputs[1]->Map());
+ m_Output->Reset(m_Data.m_Outputs[0]->Map());
+
+ using EqualFunction = ElementwiseFunction<std::equal_to<InType>>;
+ using GreaterFunction = ElementwiseFunction<std::greater<InType>>;
+ using GreaterOrEqualFunction = ElementwiseFunction<std::greater_equal<InType>>;
+ using LessFunction = ElementwiseFunction<std::less<InType>>;
+ using LessOrEqualFunction = ElementwiseFunction<std::less_equal<InType>>;
+ using NotEqualFunction = ElementwiseFunction<std::not_equal_to<InType>>;
+
+ switch (m_Data.m_Parameters.m_Operation)
+ {
+ case ComparisonOperation::Equal:
+ {
+ EqualFunction(inShape0, inShape1, outShape, *m_Input0, *m_Input1, *m_Output);
+ break;
+ }
+ case ComparisonOperation::Greater:
+ {
+ GreaterFunction(inShape0, inShape1, outShape, *m_Input0, *m_Input1, *m_Output);
+ break;
+ }
+ case ComparisonOperation::GreaterOrEqual:
+ {
+ GreaterOrEqualFunction(inShape0, inShape1, outShape, *m_Input0, *m_Input1, *m_Output);
+ break;
+ }
+ case ComparisonOperation::Less:
+ {
+ LessFunction(inShape0, inShape1, outShape, *m_Input0, *m_Input1, *m_Output);
+ break;
+ }
+ case ComparisonOperation::LessOrEqual:
+ {
+ LessOrEqualFunction(inShape0, inShape1, outShape, *m_Input0, *m_Input1, *m_Output);
+ break;
+ }
+ case ComparisonOperation::NotEqual:
+ {
+ NotEqualFunction(inShape0, inShape1, outShape, *m_Input0, *m_Input1, *m_Output);
+ break;
+ }
+ default:
+ {
+ throw InvalidArgumentException(std::string("Unsupported comparison operation ") +
+ GetComparisonOperationAsCString(m_Data.m_Parameters.m_Operation), CHECK_LOCATION());
+ }
+ }
+}
+
+} // namespace armnn
diff --git a/src/backends/reference/workloads/RefComparisonWorkload.hpp b/src/backends/reference/workloads/RefComparisonWorkload.hpp
new file mode 100644
index 0000000000..a19e4a0540
--- /dev/null
+++ b/src/backends/reference/workloads/RefComparisonWorkload.hpp
@@ -0,0 +1,34 @@
+//
+// Copyright © 2019 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+#include "BaseIterator.hpp"
+
+#include <backendsCommon/Workload.hpp>
+#include <backendsCommon/WorkloadData.hpp>
+
+namespace armnn
+{
+
+class RefComparisonWorkload : public BaseWorkload<ComparisonQueueDescriptor>
+{
+public:
+ using BaseWorkload<ComparisonQueueDescriptor>::m_Data;
+
+ RefComparisonWorkload(const ComparisonQueueDescriptor& descriptor, const WorkloadInfo& info);
+ void PostAllocationConfigure() override;
+ void Execute() const override;
+
+private:
+ using InType = float;
+ using OutType = bool;
+
+ std::unique_ptr<Decoder<InType>> m_Input0;
+ std::unique_ptr<Decoder<InType>> m_Input1;
+ std::unique_ptr<Encoder<OutType>> m_Output;
+};
+
+} // namespace armnn
diff --git a/src/backends/reference/workloads/RefElementwiseWorkload.cpp b/src/backends/reference/workloads/RefElementwiseWorkload.cpp
index 6431348bc2..7e02f032ef 100644
--- a/src/backends/reference/workloads/RefElementwiseWorkload.cpp
+++ b/src/backends/reference/workloads/RefElementwiseWorkload.cpp
@@ -86,11 +86,3 @@ template class armnn::RefElementwiseWorkload<armnn::maximum<float>,
template class armnn::RefElementwiseWorkload<armnn::minimum<float>,
armnn::MinimumQueueDescriptor,
armnn::StringMapping::RefMinimumWorkload_Execute>;
-
-template class armnn::RefElementwiseWorkload<std::equal_to<float>,
- armnn::EqualQueueDescriptor,
- armnn::StringMapping::RefEqualWorkload_Execute>;
-
-template class armnn::RefElementwiseWorkload<std::greater<float>,
- armnn::GreaterQueueDescriptor,
- armnn::StringMapping::RefGreaterWorkload_Execute>;
diff --git a/src/backends/reference/workloads/RefElementwiseWorkload.hpp b/src/backends/reference/workloads/RefElementwiseWorkload.hpp
index 651942e9e5..ee0d80b172 100644
--- a/src/backends/reference/workloads/RefElementwiseWorkload.hpp
+++ b/src/backends/reference/workloads/RefElementwiseWorkload.hpp
@@ -65,13 +65,4 @@ using RefMinimumWorkload =
MinimumQueueDescriptor,
StringMapping::RefMinimumWorkload_Execute>;
-using RefEqualWorkload =
- RefElementwiseWorkload<std::equal_to<float>,
- armnn::EqualQueueDescriptor,
- armnn::StringMapping::RefEqualWorkload_Execute>;
-
-using RefGreaterWorkload =
- RefElementwiseWorkload<std::greater<float>,
- armnn::GreaterQueueDescriptor,
- armnn::StringMapping::RefGreaterWorkload_Execute>;
} // armnn
diff --git a/src/backends/reference/workloads/RefWorkloads.hpp b/src/backends/reference/workloads/RefWorkloads.hpp
index 79d1935823..1f9ad4a19a 100644
--- a/src/backends/reference/workloads/RefWorkloads.hpp
+++ b/src/backends/reference/workloads/RefWorkloads.hpp
@@ -20,6 +20,7 @@
#include "RefArgMinMaxWorkload.hpp"
#include "RefBatchNormalizationWorkload.hpp"
#include "RefBatchToSpaceNdWorkload.hpp"
+#include "RefComparisonWorkload.hpp"
#include "RefConvolution2dWorkload.hpp"
#include "RefConstantWorkload.hpp"
#include "RefConcatWorkload.hpp"
diff --git a/src/backends/reference/workloads/StringMapping.hpp b/src/backends/reference/workloads/StringMapping.hpp
index 073a5a6833..1654b78088 100644
--- a/src/backends/reference/workloads/StringMapping.hpp
+++ b/src/backends/reference/workloads/StringMapping.hpp
@@ -18,9 +18,7 @@ struct StringMapping
public:
enum Id {
RefAdditionWorkload_Execute,
- RefEqualWorkload_Execute,
RefDivisionWorkload_Execute,
- RefGreaterWorkload_Execute,
RefMaximumWorkload_Execute,
RefMinimumWorkload_Execute,
RefMultiplicationWorkload_Execute,
@@ -40,8 +38,6 @@ private:
{
m_Strings[RefAdditionWorkload_Execute] = "RefAdditionWorkload_Execute";
m_Strings[RefDivisionWorkload_Execute] = "RefDivisionWorkload_Execute";
- m_Strings[RefEqualWorkload_Execute] = "RefEqualWorkload_Execute";
- m_Strings[RefGreaterWorkload_Execute] = "RefGreaterWorkload_Execute";
m_Strings[RefMaximumWorkload_Execute] = "RefMaximumWorkload_Execute";
m_Strings[RefMinimumWorkload_Execute] = "RefMinimumWorkload_Execute";
m_Strings[RefMultiplicationWorkload_Execute] = "RefMultiplicationWorkload_Execute";