From 2e6dc3a1c5d47825535db7993ba77eb1596ae99b Mon Sep 17 00:00:00 2001
From: Sadik Armagan
Date: Wed, 3 Apr 2019 17:48:18 +0100
Subject: IVGCVSW-2861 Refactor the Reference Elementwise workload

* Refactor Reference Comparison workload
* Removed templating based on the DataType
* Implemented BaseIterator to do decode/encode

Change-Id: I18f299f47ee23772f90152c1146b42f07465e105
Signed-off-by: Sadik Armagan
Signed-off-by: Kevin May
---
 .../reference/workloads/RefComparisonWorkload.cpp | 91 ++++++++++++----------
 1 file changed, 51 insertions(+), 40 deletions(-)

(limited to 'src/backends/reference/workloads/RefComparisonWorkload.cpp')

diff --git a/src/backends/reference/workloads/RefComparisonWorkload.cpp b/src/backends/reference/workloads/RefComparisonWorkload.cpp
index fe517ff51a..bb8bb04ad3 100644
--- a/src/backends/reference/workloads/RefComparisonWorkload.cpp
+++ b/src/backends/reference/workloads/RefComparisonWorkload.cpp
@@ -11,55 +11,66 @@
 namespace armnn
 {
 
-template
-void RefFloat32ComparisonWorkload::ExecuteImpl(const char* debugString) const
+template
+void RefComparisonWorkload::Execute() const
 {
-    ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, debugString);
+    ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, StringMapping::Instance().Get(DebugString));
+    const TensorInfo& inputInfo0 = GetTensorInfo(m_Data.m_Inputs[0]);
+    const TensorInfo& inputInfo1 = GetTensorInfo(m_Data.m_Inputs[1]);
+    const TensorInfo& outputInfo = GetTensorInfo(m_Data.m_Outputs[0]);
 
-    auto data = BaseFloat32ComparisonWorkload::GetData();
-    const TensorShape& inShape0 = GetTensorInfo(data.m_Inputs[0]).GetShape();
-    const TensorShape& inShape1 = GetTensorInfo(data.m_Inputs[1]).GetShape();
-    const TensorShape& outputShape = GetTensorInfo(data.m_Outputs[0]).GetShape();
+    const TensorShape& inShape0 = inputInfo0.GetShape();
+    const TensorShape& inShape1 = inputInfo1.GetShape();
+    const TensorShape& outShape = outputInfo.GetShape();
 
-    const float* inData0 = GetInputTensorDataFloat(0, data);
-    const float* inData1 = GetInputTensorDataFloat(1, data);
-    uint8_t* outData = GetOutputTensorData(0, data);
+    switch(inputInfo0.GetDataType())
+    {
+        case armnn::DataType::QuantisedAsymm8:
+        {
+            QASymm8Decoder decodeIterator0(GetInputTensorDataU8(0, m_Data),
+                                           inputInfo0.GetQuantizationScale(),
+                                           inputInfo0.GetQuantizationOffset());
 
-    ElementwiseFunction(inShape0,
-                        inShape1,
-                        outputShape,
-                        inData0,
-                        inData1,
-                        outData);
+            QASymm8Decoder decodeIterator1(GetInputTensorDataU8(1, m_Data),
+                                           inputInfo1.GetQuantizationScale(),
+                                           inputInfo1.GetQuantizationOffset());
 
-}
-
-template
-void RefUint8ComparisonWorkload::ExecuteImpl(const char* debugString) const
-{
-    ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, debugString);
-
-    auto data = BaseUint8ComparisonWorkload::GetData();
-    const TensorShape& inputInfo0 = GetTensorInfo(data.m_Inputs[0]).GetShape();
-    const TensorShape& inputInfo1 = GetTensorInfo(data.m_Inputs[1]).GetShape();
-    const TensorShape& outputShape = GetTensorInfo(data.m_Outputs[0]).GetShape();
+            BooleanEncoder encodeIterator0(GetOutputTensorDataU8(0, m_Data));
 
-    const uint8_t* inData0 = GetInputTensorData(0, data);
-    const uint8_t* inData1 = GetInputTensorData(1, data);
-    uint8_t* outData = GetOutputTensorData(0, data);
+            ElementwiseFunction(inShape0,
+                                inShape1,
+                                outShape,
+                                decodeIterator0,
+                                decodeIterator1,
+                                encodeIterator0);
+            break;
+        }
+        case armnn::DataType::Float32:
+        {
+            FloatDecoder decodeIterator0(GetInputTensorDataFloat(0, m_Data));
+            FloatDecoder decodeIterator1(GetInputTensorDataFloat(1, m_Data));
+
+            BooleanEncoder encodeIterator0(GetOutputTensorDataU8(0, m_Data));
 
-    ElementwiseFunction(inputInfo0,
-                        inputInfo1,
-                        outputShape,
-                        inData0,
-                        inData1,
-                        outData);
+            ElementwiseFunction(inShape0,
+                                inShape1,
+                                outShape,
+                                decodeIterator0,
+                                decodeIterator1,
+                                encodeIterator0);
+            break;
+        }
+        default:
+            BOOST_ASSERT_MSG(false, "RefComparisonWorkload: Not supported Data Type!");
+            break;
+    }
 }
 
 }
 
-template class armnn::RefFloat32ComparisonWorkload>;
-template class armnn::RefUint8ComparisonWorkload>;
+template class armnn::RefComparisonWorkload,
+                                            armnn::EqualQueueDescriptor,
+                                            armnn::StringMapping::RefEqualWorkload_Execute>;
 
-template class armnn::RefFloat32ComparisonWorkload>;
-template class armnn::RefUint8ComparisonWorkload>;
+template class armnn::RefComparisonWorkload,
+                                            armnn::GreaterQueueDescriptor,
+                                            armnn::StringMapping::RefGreaterWorkload_Execute>;
-- 
cgit v1.2.1
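
The core of this refactor is the decode/encode iterator pattern named in the commit message: each input tensor is wrapped in a decoder that dequantises (or passes through) values on read, and the output in an encoder that writes the result type, so a single comparison loop can serve both Float32 and QuantisedAsymm8 tensors. A minimal sketch of that idea follows; the names QASymm8DecoderSketch, BooleanEncoderSketch and ElementwiseCompareSketch are hypothetical simplifications for illustration, not the actual ArmNN BaseIterator/Encoders API.

// Minimal sketch of the decode/encode iterator idea used by this patch.
// All names are hypothetical and simplified; the real ArmNN classes follow
// the BaseIterator interface from the reference backend headers.
#include <cstddef>
#include <cstdint>
#include <functional>

// Reads quantised uint8 data and returns dequantised floats.
struct QASymm8DecoderSketch
{
    QASymm8DecoderSketch(const uint8_t* data, float scale, int32_t offset)
        : m_Data(data), m_Scale(scale), m_Offset(offset) {}

    float Get(std::size_t i) const
    {
        // Dequantise: (q - offset) * scale
        return (static_cast<int32_t>(m_Data[i]) - m_Offset) * m_Scale;
    }

    const uint8_t* m_Data;
    float          m_Scale;
    int32_t        m_Offset;
};

// Writes comparison results as 0/1 into a uint8 output buffer.
struct BooleanEncoderSketch
{
    explicit BooleanEncoderSketch(uint8_t* data) : m_Data(data) {}

    void Set(std::size_t i, bool value) const { m_Data[i] = value ? 1u : 0u; }

    uint8_t* m_Data;
};

// The loop is templated only on the functor and the iterator types, so one
// implementation covers every input DataType the decoders know how to read.
template <typename Functor, typename DecoderT, typename EncoderT>
void ElementwiseCompareSketch(std::size_t numElements,
                              const DecoderT& in0,
                              const DecoderT& in1,
                              const EncoderT& out)
{
    Functor f;
    for (std::size_t i = 0; i < numElements; ++i)
    {
        out.Set(i, f(in0.Get(i), in1.Get(i)));
    }
}

// Example usage: greater-than on two quantised inputs.
// ElementwiseCompareSketch<std::greater<float>>(n, dec0, dec1, enc);

In the patch itself this is why the switch on inputInfo0.GetDataType() only chooses which decoders to construct; the shared ElementwiseFunction then handles broadcasting and the per-element comparison for every supported type.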