diff options
author | kevmay01 <kevin.may@arm.com> | 2019-01-24 14:05:09 +0000 |
---|---|---|
committer | kevmay01 <kevin.may@arm.com> | 2019-01-24 14:05:09 +0000 |
commit | 2b4d88e34ac1f965417fd236fd4786f26bae2042 (patch) | |
tree | 4518b52c6a22e33c4b467588a2843c9d5f1a9ee6 /src/backends/reference/workloads/RefComparisonWorkload.cpp | |
parent | 94412aff782472be54dce4328e2ecee0225b3e97 (diff) | |
download | armnn-2b4d88e34ac1f965417fd236fd4786f26bae2042.tar.gz |
IVGCVSW-2503 Refactor RefElementwiseWorkload around Equal and Greater
* Remove Equal and Greater from RefElementwiseWorkload
* Create RefComparisonWorkload and add Equal and Greater
* Update ElementwiseFunction for different input/output types
* Update TfParser to create Equal/Greater with Boolean output
* Update relevant tests to check for Boolean comparison
Change-Id: I299b7f2121769c960ac0c6139764a5f3c89c9c32
Diffstat (limited to 'src/backends/reference/workloads/RefComparisonWorkload.cpp')
-rw-r--r-- | src/backends/reference/workloads/RefComparisonWorkload.cpp | 65 |
1 files changed, 65 insertions, 0 deletions
diff --git a/src/backends/reference/workloads/RefComparisonWorkload.cpp b/src/backends/reference/workloads/RefComparisonWorkload.cpp new file mode 100644 index 0000000000..fe517ff51a --- /dev/null +++ b/src/backends/reference/workloads/RefComparisonWorkload.cpp @@ -0,0 +1,65 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#include "RefComparisonWorkload.hpp" +#include "ElementwiseFunction.hpp" +#include "RefWorkloadUtils.hpp" +#include "Profiling.hpp" +#include <vector> + +namespace armnn { + +template<typename ParentDescriptor, typename Functor> +void RefFloat32ComparisonWorkload<ParentDescriptor, Functor>::ExecuteImpl(const char* debugString) const +{ + ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, debugString); + + auto data = BaseFloat32ComparisonWorkload<ParentDescriptor>::GetData(); + const TensorShape& inShape0 = GetTensorInfo(data.m_Inputs[0]).GetShape(); + const TensorShape& inShape1 = GetTensorInfo(data.m_Inputs[1]).GetShape(); + const TensorShape& outputShape = GetTensorInfo(data.m_Outputs[0]).GetShape(); + + const float* inData0 = GetInputTensorDataFloat(0, data); + const float* inData1 = GetInputTensorDataFloat(1, data); + uint8_t* outData = GetOutputTensorData<uint8_t>(0, data); + + ElementwiseFunction<Functor, float, uint8_t>(inShape0, + inShape1, + outputShape, + inData0, + inData1, + outData); + +} + +template<typename ParentDescriptor, typename Functor> +void RefUint8ComparisonWorkload<ParentDescriptor, Functor>::ExecuteImpl(const char* debugString) const +{ + ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, debugString); + + auto data = BaseUint8ComparisonWorkload<ParentDescriptor>::GetData(); + const TensorShape& inputInfo0 = GetTensorInfo(data.m_Inputs[0]).GetShape(); + const TensorShape& inputInfo1 = GetTensorInfo(data.m_Inputs[1]).GetShape(); + const TensorShape& outputShape = GetTensorInfo(data.m_Outputs[0]).GetShape(); + + const uint8_t* inData0 = GetInputTensorData<uint8_t>(0, data); + const uint8_t* inData1 = GetInputTensorData<uint8_t>(1, data); + uint8_t* outData = GetOutputTensorData<uint8_t>(0, data); + + ElementwiseFunction<Functor, uint8_t, uint8_t>(inputInfo0, + inputInfo1, + outputShape, + inData0, + inData1, + outData); +} + +} + +template class armnn::RefFloat32ComparisonWorkload<armnn::EqualQueueDescriptor, std::equal_to<float>>; +template class armnn::RefUint8ComparisonWorkload<armnn::EqualQueueDescriptor, std::equal_to<uint8_t>>; + +template class armnn::RefFloat32ComparisonWorkload<armnn::GreaterQueueDescriptor, std::greater<float>>; +template class armnn::RefUint8ComparisonWorkload<armnn::GreaterQueueDescriptor, std::greater<uint8_t>>; |