aboutsummaryrefslogtreecommitdiff
path: root/src/backends/reference/workloads/RefComparisonWorkload.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/reference/workloads/RefComparisonWorkload.cpp')
-rw-r--r--src/backends/reference/workloads/RefComparisonWorkload.cpp91
1 files changed, 51 insertions, 40 deletions
diff --git a/src/backends/reference/workloads/RefComparisonWorkload.cpp b/src/backends/reference/workloads/RefComparisonWorkload.cpp
index fe517ff51a..bb8bb04ad3 100644
--- a/src/backends/reference/workloads/RefComparisonWorkload.cpp
+++ b/src/backends/reference/workloads/RefComparisonWorkload.cpp
@@ -11,55 +11,66 @@
namespace armnn {
-template<typename ParentDescriptor, typename Functor>
-void RefFloat32ComparisonWorkload<ParentDescriptor, Functor>::ExecuteImpl(const char* debugString) const
+template <typename Functor, typename ParentDescriptor, typename armnn::StringMapping::Id DebugString>
+void RefComparisonWorkload<Functor, ParentDescriptor, DebugString>::Execute() const
{
- ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, debugString);
+ ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, StringMapping::Instance().Get(DebugString));
+ const TensorInfo& inputInfo0 = GetTensorInfo(m_Data.m_Inputs[0]);
+ const TensorInfo& inputInfo1 = GetTensorInfo(m_Data.m_Inputs[1]);
+ const TensorInfo& outputInfo = GetTensorInfo(m_Data.m_Outputs[0]);
- auto data = BaseFloat32ComparisonWorkload<ParentDescriptor>::GetData();
- const TensorShape& inShape0 = GetTensorInfo(data.m_Inputs[0]).GetShape();
- const TensorShape& inShape1 = GetTensorInfo(data.m_Inputs[1]).GetShape();
- const TensorShape& outputShape = GetTensorInfo(data.m_Outputs[0]).GetShape();
+ const TensorShape& inShape0 = inputInfo0.GetShape();
+ const TensorShape& inShape1 = inputInfo1.GetShape();
+ const TensorShape& outShape = outputInfo.GetShape();
- const float* inData0 = GetInputTensorDataFloat(0, data);
- const float* inData1 = GetInputTensorDataFloat(1, data);
- uint8_t* outData = GetOutputTensorData<uint8_t>(0, data);
+ switch(inputInfo0.GetDataType())
+ {
+ case armnn::DataType::QuantisedAsymm8:
+ {
+ QASymm8Decoder decodeIterator0(GetInputTensorDataU8(0, m_Data),
+ inputInfo0.GetQuantizationScale(),
+ inputInfo0.GetQuantizationOffset());
- ElementwiseFunction<Functor, float, uint8_t>(inShape0,
- inShape1,
- outputShape,
- inData0,
- inData1,
- outData);
+ QASymm8Decoder decodeIterator1(GetInputTensorDataU8(1, m_Data),
+ inputInfo1.GetQuantizationScale(),
+ inputInfo1.GetQuantizationOffset());
-}
-
-template<typename ParentDescriptor, typename Functor>
-void RefUint8ComparisonWorkload<ParentDescriptor, Functor>::ExecuteImpl(const char* debugString) const
-{
- ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, debugString);
-
- auto data = BaseUint8ComparisonWorkload<ParentDescriptor>::GetData();
- const TensorShape& inputInfo0 = GetTensorInfo(data.m_Inputs[0]).GetShape();
- const TensorShape& inputInfo1 = GetTensorInfo(data.m_Inputs[1]).GetShape();
- const TensorShape& outputShape = GetTensorInfo(data.m_Outputs[0]).GetShape();
+ BooleanEncoder encodeIterator0(GetOutputTensorDataU8(0, m_Data));
- const uint8_t* inData0 = GetInputTensorData<uint8_t>(0, data);
- const uint8_t* inData1 = GetInputTensorData<uint8_t>(1, data);
- uint8_t* outData = GetOutputTensorData<uint8_t>(0, data);
+ ElementwiseFunction<Functor, Decoder, ComparisonEncoder>(inShape0,
+ inShape1,
+ outShape,
+ decodeIterator0,
+ decodeIterator1,
+ encodeIterator0);
+ break;
+ }
+ case armnn::DataType::Float32:
+ {
+ FloatDecoder decodeIterator0(GetInputTensorDataFloat(0, m_Data));
+ FloatDecoder decodeIterator1(GetInputTensorDataFloat(1, m_Data));
+ BooleanEncoder encodeIterator0(GetOutputTensorDataU8(0, m_Data));
- ElementwiseFunction<Functor, uint8_t, uint8_t>(inputInfo0,
- inputInfo1,
- outputShape,
- inData0,
- inData1,
- outData);
+ ElementwiseFunction<Functor, Decoder, ComparisonEncoder>(inShape0,
+ inShape1,
+ outShape,
+ decodeIterator0,
+ decodeIterator1,
+ encodeIterator0);
+ break;
+ }
+ default:
+ BOOST_ASSERT_MSG(false, "RefComparisonWorkload: Not supported Data Type!");
+ break;
+ }
}
}
-template class armnn::RefFloat32ComparisonWorkload<armnn::EqualQueueDescriptor, std::equal_to<float>>;
-template class armnn::RefUint8ComparisonWorkload<armnn::EqualQueueDescriptor, std::equal_to<uint8_t>>;
+template class armnn::RefComparisonWorkload<std::equal_to<float>,
+ armnn::EqualQueueDescriptor,
+ armnn::StringMapping::RefEqualWorkload_Execute>;
-template class armnn::RefFloat32ComparisonWorkload<armnn::GreaterQueueDescriptor, std::greater<float>>;
-template class armnn::RefUint8ComparisonWorkload<armnn::GreaterQueueDescriptor, std::greater<uint8_t>>;
+template class armnn::RefComparisonWorkload<std::greater<float>,
+ armnn::GreaterQueueDescriptor,
+ armnn::StringMapping::RefGreaterWorkload_Execute>;