aboutsummaryrefslogtreecommitdiff
path: root/src/backends/reference/workloads/RefComparisonWorkload.cpp
diff options
context:
space:
mode:
authorkevmay01 <kevin.may@arm.com>2019-01-24 14:05:09 +0000
committerkevmay01 <kevin.may@arm.com>2019-01-24 14:05:09 +0000
commit2b4d88e34ac1f965417fd236fd4786f26bae2042 (patch)
tree4518b52c6a22e33c4b467588a2843c9d5f1a9ee6 /src/backends/reference/workloads/RefComparisonWorkload.cpp
parent94412aff782472be54dce4328e2ecee0225b3e97 (diff)
downloadarmnn-2b4d88e34ac1f965417fd236fd4786f26bae2042.tar.gz
IVGCVSW-2503 Refactor RefElementwiseWorkload around Equal and Greater
* Remove Equal and Greater from RefElementwiseWorkload * Create RefComparisonWorkload and add Equal and Greater * Update ElementwiseFunction for different input/output types * Update TfParser to create Equal/Greater with Boolean output * Update relevant tests to check for Boolean comparison Change-Id: I299b7f2121769c960ac0c6139764a5f3c89c9c32
Diffstat (limited to 'src/backends/reference/workloads/RefComparisonWorkload.cpp')
-rw-r--r--src/backends/reference/workloads/RefComparisonWorkload.cpp65
1 files changed, 65 insertions, 0 deletions
diff --git a/src/backends/reference/workloads/RefComparisonWorkload.cpp b/src/backends/reference/workloads/RefComparisonWorkload.cpp
new file mode 100644
index 0000000000..fe517ff51a
--- /dev/null
+++ b/src/backends/reference/workloads/RefComparisonWorkload.cpp
@@ -0,0 +1,65 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include "RefComparisonWorkload.hpp"
+#include "ElementwiseFunction.hpp"
+#include "RefWorkloadUtils.hpp"
+#include "Profiling.hpp"
+#include <vector>
+
+namespace armnn {
+
+template<typename ParentDescriptor, typename Functor>
+void RefFloat32ComparisonWorkload<ParentDescriptor, Functor>::ExecuteImpl(const char* debugString) const
+{
+ ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, debugString);
+
+ auto data = BaseFloat32ComparisonWorkload<ParentDescriptor>::GetData();
+ const TensorShape& inShape0 = GetTensorInfo(data.m_Inputs[0]).GetShape();
+ const TensorShape& inShape1 = GetTensorInfo(data.m_Inputs[1]).GetShape();
+ const TensorShape& outputShape = GetTensorInfo(data.m_Outputs[0]).GetShape();
+
+ const float* inData0 = GetInputTensorDataFloat(0, data);
+ const float* inData1 = GetInputTensorDataFloat(1, data);
+ uint8_t* outData = GetOutputTensorData<uint8_t>(0, data);
+
+ ElementwiseFunction<Functor, float, uint8_t>(inShape0,
+ inShape1,
+ outputShape,
+ inData0,
+ inData1,
+ outData);
+
+}
+
+template<typename ParentDescriptor, typename Functor>
+void RefUint8ComparisonWorkload<ParentDescriptor, Functor>::ExecuteImpl(const char* debugString) const
+{
+ ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, debugString);
+
+ auto data = BaseUint8ComparisonWorkload<ParentDescriptor>::GetData();
+ const TensorShape& inputInfo0 = GetTensorInfo(data.m_Inputs[0]).GetShape();
+ const TensorShape& inputInfo1 = GetTensorInfo(data.m_Inputs[1]).GetShape();
+ const TensorShape& outputShape = GetTensorInfo(data.m_Outputs[0]).GetShape();
+
+ const uint8_t* inData0 = GetInputTensorData<uint8_t>(0, data);
+ const uint8_t* inData1 = GetInputTensorData<uint8_t>(1, data);
+ uint8_t* outData = GetOutputTensorData<uint8_t>(0, data);
+
+ ElementwiseFunction<Functor, uint8_t, uint8_t>(inputInfo0,
+ inputInfo1,
+ outputShape,
+ inData0,
+ inData1,
+ outData);
+}
+
+}
+
+template class armnn::RefFloat32ComparisonWorkload<armnn::EqualQueueDescriptor, std::equal_to<float>>;
+template class armnn::RefUint8ComparisonWorkload<armnn::EqualQueueDescriptor, std::equal_to<uint8_t>>;
+
+template class armnn::RefFloat32ComparisonWorkload<armnn::GreaterQueueDescriptor, std::greater<float>>;
+template class armnn::RefUint8ComparisonWorkload<armnn::GreaterQueueDescriptor, std::greater<uint8_t>>;