aboutsummaryrefslogtreecommitdiff
path: root/src/backends/reference/workloads/RefComparisonWorkload.hpp
diff options
context:
space:
mode:
authorkevmay01 <kevin.may@arm.com>2019-01-24 14:05:09 +0000
committerkevmay01 <kevin.may@arm.com>2019-01-24 14:05:09 +0000
commit2b4d88e34ac1f965417fd236fd4786f26bae2042 (patch)
tree4518b52c6a22e33c4b467588a2843c9d5f1a9ee6 /src/backends/reference/workloads/RefComparisonWorkload.hpp
parent94412aff782472be54dce4328e2ecee0225b3e97 (diff)
downloadarmnn-2b4d88e34ac1f965417fd236fd4786f26bae2042.tar.gz
IVGCVSW-2503 Refactor RefElementwiseWorkload around Equal and Greater
* Remove Equal and Greater from RefElementwiseWorkload * Create RefComparisonWorkload and add Equal and Greater * Update ElementwiseFunction for different input/output types * Update TfParser to create Equal/Greater with Boolean output * Update relevant tests to check for Boolean comparison Change-Id: I299b7f2121769c960ac0c6139764a5f3c89c9c32
Diffstat (limited to 'src/backends/reference/workloads/RefComparisonWorkload.hpp')
-rw-r--r--src/backends/reference/workloads/RefComparisonWorkload.hpp92
1 files changed, 92 insertions, 0 deletions
diff --git a/src/backends/reference/workloads/RefComparisonWorkload.hpp b/src/backends/reference/workloads/RefComparisonWorkload.hpp
new file mode 100644
index 0000000000..524d20625a
--- /dev/null
+++ b/src/backends/reference/workloads/RefComparisonWorkload.hpp
@@ -0,0 +1,92 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+#include <armnn/Types.hpp>
+#include <backendsCommon/Workload.hpp>
+#include <backendsCommon/WorkloadData.hpp>
+#include "StringMapping.hpp"
+
+namespace armnn
+{
+
+template <typename Functor,
+ typename armnn::DataType DataType,
+ typename ParentDescriptor,
+ typename armnn::StringMapping::Id DebugString>
+class RefComparisonWorkload
+{
+ // Needs specialization. The default is empty on purpose.
+};
+
+template <typename ParentDescriptor, typename Functor>
+class RefFloat32ComparisonWorkload : public BaseFloat32ComparisonWorkload<ParentDescriptor>
+{
+public:
+ using BaseFloat32ComparisonWorkload<ParentDescriptor>::BaseFloat32ComparisonWorkload;
+ void ExecuteImpl(const char * debugString) const;
+};
+
+template <typename Functor, typename ParentDescriptor, typename armnn::StringMapping::Id DebugString>
+class RefComparisonWorkload<Functor, armnn::DataType::Float32, ParentDescriptor, DebugString>
+ : public RefFloat32ComparisonWorkload<ParentDescriptor, Functor>
+{
+public:
+ using RefFloat32ComparisonWorkload<ParentDescriptor, Functor>::RefFloat32ComparisonWorkload;
+
+ virtual void Execute() const override
+ {
+ using Parent = RefFloat32ComparisonWorkload<ParentDescriptor, Functor>;
+ Parent::ExecuteImpl(StringMapping::Instance().Get(DebugString));
+ }
+};
+
+template <typename ParentDescriptor, typename Functor>
+class RefUint8ComparisonWorkload : public BaseUint8ComparisonWorkload<ParentDescriptor>
+{
+public:
+ using BaseUint8ComparisonWorkload<ParentDescriptor>::BaseUint8ComparisonWorkload;
+ void ExecuteImpl(const char * debugString) const;
+};
+
+template <typename Functor, typename ParentDescriptor, typename armnn::StringMapping::Id DebugString>
+class RefComparisonWorkload<Functor, armnn::DataType::QuantisedAsymm8, ParentDescriptor, DebugString>
+ : public RefUint8ComparisonWorkload<ParentDescriptor, Functor>
+{
+public:
+ using RefUint8ComparisonWorkload<ParentDescriptor, Functor>::RefUint8ComparisonWorkload;
+
+ virtual void Execute() const override
+ {
+ using Parent = RefUint8ComparisonWorkload<ParentDescriptor, Functor>;
+ Parent::ExecuteImpl(StringMapping::Instance().Get(DebugString));
+ }
+};
+
+using RefEqualFloat32Workload =
+ RefComparisonWorkload<std::equal_to<float>,
+ DataType::Float32,
+ EqualQueueDescriptor,
+ StringMapping::RefEqualWorkload_Execute>;
+
+using RefEqualUint8Workload =
+ RefComparisonWorkload<std::equal_to<uint8_t>,
+ DataType::QuantisedAsymm8,
+ EqualQueueDescriptor,
+ StringMapping::RefEqualWorkload_Execute>;
+
+using RefGreaterFloat32Workload =
+ RefComparisonWorkload<std::greater<float>,
+ DataType::Float32,
+ GreaterQueueDescriptor,
+ StringMapping::RefGreaterWorkload_Execute>;
+
+using RefGreaterUint8Workload =
+ RefComparisonWorkload<std::greater<uint8_t>,
+ DataType::QuantisedAsymm8,
+ GreaterQueueDescriptor,
+ StringMapping::RefGreaterWorkload_Execute>;
+} // armnn