From 2b4d88e34ac1f965417fd236fd4786f26bae2042 Mon Sep 17 00:00:00 2001 From: kevmay01 Date: Thu, 24 Jan 2019 14:05:09 +0000 Subject: IVGCVSW-2503 Refactor RefElementwiseWorkload around Equal and Greater * Remove Equal and Greater from RefElementwiseWorkload * Create RefComparisonWorkload and add Equal and Greater * Update ElementwiseFunction for different input/output types * Update TfParser to create Equal/Greater with Boolean output * Update relevant tests to check for Boolean comparison Change-Id: I299b7f2121769c960ac0c6139764a5f3c89c9c32 --- .../reference/workloads/RefComparisonWorkload.hpp | 92 ++++++++++++++++++++++ 1 file changed, 92 insertions(+) create mode 100644 src/backends/reference/workloads/RefComparisonWorkload.hpp (limited to 'src/backends/reference/workloads/RefComparisonWorkload.hpp') diff --git a/src/backends/reference/workloads/RefComparisonWorkload.hpp b/src/backends/reference/workloads/RefComparisonWorkload.hpp new file mode 100644 index 0000000000..524d20625a --- /dev/null +++ b/src/backends/reference/workloads/RefComparisonWorkload.hpp @@ -0,0 +1,92 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#pragma once + +#include +#include +#include +#include "StringMapping.hpp" + +namespace armnn +{ + +template +class RefComparisonWorkload +{ + // Needs specialization. The default is empty on purpose. +}; + +template +class RefFloat32ComparisonWorkload : public BaseFloat32ComparisonWorkload +{ +public: + using BaseFloat32ComparisonWorkload::BaseFloat32ComparisonWorkload; + void ExecuteImpl(const char * debugString) const; +}; + +template +class RefComparisonWorkload + : public RefFloat32ComparisonWorkload +{ +public: + using RefFloat32ComparisonWorkload::RefFloat32ComparisonWorkload; + + virtual void Execute() const override + { + using Parent = RefFloat32ComparisonWorkload; + Parent::ExecuteImpl(StringMapping::Instance().Get(DebugString)); + } +}; + +template +class RefUint8ComparisonWorkload : public BaseUint8ComparisonWorkload +{ +public: + using BaseUint8ComparisonWorkload::BaseUint8ComparisonWorkload; + void ExecuteImpl(const char * debugString) const; +}; + +template +class RefComparisonWorkload + : public RefUint8ComparisonWorkload +{ +public: + using RefUint8ComparisonWorkload::RefUint8ComparisonWorkload; + + virtual void Execute() const override + { + using Parent = RefUint8ComparisonWorkload; + Parent::ExecuteImpl(StringMapping::Instance().Get(DebugString)); + } +}; + +using RefEqualFloat32Workload = + RefComparisonWorkload, + DataType::Float32, + EqualQueueDescriptor, + StringMapping::RefEqualWorkload_Execute>; + +using RefEqualUint8Workload = + RefComparisonWorkload, + DataType::QuantisedAsymm8, + EqualQueueDescriptor, + StringMapping::RefEqualWorkload_Execute>; + +using RefGreaterFloat32Workload = + RefComparisonWorkload, + DataType::Float32, + GreaterQueueDescriptor, + StringMapping::RefGreaterWorkload_Execute>; + +using RefGreaterUint8Workload = + RefComparisonWorkload, + DataType::QuantisedAsymm8, + GreaterQueueDescriptor, + StringMapping::RefGreaterWorkload_Execute>; +} // armnn -- cgit v1.2.1