From 2e6dc3a1c5d47825535db7993ba77eb1596ae99b Mon Sep 17 00:00:00 2001 From: Sadik Armagan Date: Wed, 3 Apr 2019 17:48:18 +0100 Subject: IVGCVSW-2861 Refactor the Reference Elementwise workload * Refactor Reference Comparison workload * Removed templating based on the DataType * Implemented BaseIterator to do decode/encode Change-Id: I18f299f47ee23772f90152c1146b42f07465e105 Signed-off-by: Sadik Armagan Signed-off-by: Kevin May --- .../reference/workloads/RefElementwiseWorkload.cpp | 115 ++++++++------------- 1 file changed, 44 insertions(+), 71 deletions(-) (limited to 'src/backends/reference/workloads/RefElementwiseWorkload.cpp') diff --git a/src/backends/reference/workloads/RefElementwiseWorkload.cpp b/src/backends/reference/workloads/RefElementwiseWorkload.cpp index 356d7a0c16..6e6e1d5f21 100644 --- a/src/backends/reference/workloads/RefElementwiseWorkload.cpp +++ b/src/backends/reference/workloads/RefElementwiseWorkload.cpp @@ -14,14 +14,10 @@ namespace armnn { -template -void RefElementwiseWorkload::Execute() const +template +void RefElementwiseWorkload::Execute() const { ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, StringMapping::Instance().Get(DebugString)); - const TensorInfo& inputInfo0 = GetTensorInfo(m_Data.m_Inputs[0]); const TensorInfo& inputInfo1 = GetTensorInfo(m_Data.m_Inputs[1]); const TensorInfo& outputInfo = GetTensorInfo(m_Data.m_Outputs[0]); @@ -30,32 +26,46 @@ void RefElementwiseWorkload::E const TensorShape& inShape1 = inputInfo1.GetShape(); const TensorShape& outShape = outputInfo.GetShape(); - switch(DataType) + switch(inputInfo0.GetDataType()) { case armnn::DataType::QuantisedAsymm8: { - std::vector results(outputInfo.GetNumElements()); - ElementwiseFunction(inShape0, - inShape1, - outShape, - Dequantize(GetInputTensorDataU8(0, m_Data), inputInfo0).data(), - Dequantize(GetInputTensorDataU8(1, m_Data), inputInfo1).data(), - results.data()); - Quantize(GetOutputTensorDataU8(0, m_Data), results.data(), outputInfo); + QASymm8Decoder decodeIterator0(GetInputTensorDataU8(0, m_Data), + inputInfo0.GetQuantizationScale(), + inputInfo0.GetQuantizationOffset()); + + QASymm8Decoder decodeIterator1(GetInputTensorDataU8(1, m_Data), + inputInfo1.GetQuantizationScale(), + inputInfo1.GetQuantizationOffset()); + + QASymm8Encoder encodeIterator0(GetOutputTensorDataU8(0, m_Data), + outputInfo.GetQuantizationScale(), + outputInfo.GetQuantizationOffset()); + + ElementwiseFunction(inShape0, + inShape1, + outShape, + decodeIterator0, + decodeIterator1, + encodeIterator0); break; } case armnn::DataType::Float32: { - ElementwiseFunction(inShape0, - inShape1, - outShape, - GetInputTensorDataFloat(0, m_Data), - GetInputTensorDataFloat(1, m_Data), - GetOutputTensorDataFloat(0, m_Data)); + FloatDecoder decodeIterator0(GetInputTensorDataFloat(0, m_Data)); + FloatDecoder decodeIterator1(GetInputTensorDataFloat(1, m_Data)); + FloatEncoder encodeIterator0(GetOutputTensorDataFloat(0, m_Data)); + + ElementwiseFunction(inShape0, + inShape1, + outShape, + decodeIterator0, + decodeIterator1, + encodeIterator0); break; } default: - BOOST_ASSERT_MSG(false, "Unknown Data Type!"); + BOOST_ASSERT_MSG(false, "RefElementwiseWorkload: Not supported Data Type!"); break; } } @@ -63,62 +73,25 @@ void RefElementwiseWorkload::E } template class armnn::RefElementwiseWorkload, - armnn::DataType::Float32, - armnn::AdditionQueueDescriptor, - armnn::StringMapping::RefAdditionWorkload_Execute>; - -template class armnn::RefElementwiseWorkload, - armnn::DataType::QuantisedAsymm8, - armnn::AdditionQueueDescriptor, - armnn::StringMapping::RefAdditionWorkload_Execute>; + armnn::AdditionQueueDescriptor, + armnn::StringMapping::RefAdditionWorkload_Execute>; template class armnn::RefElementwiseWorkload, - armnn::DataType::Float32, - armnn::SubtractionQueueDescriptor, - armnn::StringMapping::RefSubtractionWorkload_Execute>; - -template class armnn::RefElementwiseWorkload, - armnn::DataType::QuantisedAsymm8, - armnn::SubtractionQueueDescriptor, - armnn::StringMapping::RefSubtractionWorkload_Execute>; - -template class armnn::RefElementwiseWorkload, - armnn::DataType::Float32, - armnn::MultiplicationQueueDescriptor, - armnn::StringMapping::RefMultiplicationWorkload_Execute>; + armnn::SubtractionQueueDescriptor, + armnn::StringMapping::RefSubtractionWorkload_Execute>; template class armnn::RefElementwiseWorkload, - armnn::DataType::QuantisedAsymm8, - armnn::MultiplicationQueueDescriptor, - armnn::StringMapping::RefMultiplicationWorkload_Execute>; + armnn::MultiplicationQueueDescriptor, + armnn::StringMapping::RefMultiplicationWorkload_Execute>; template class armnn::RefElementwiseWorkload, - armnn::DataType::Float32, - armnn::DivisionQueueDescriptor, - armnn::StringMapping::RefDivisionWorkload_Execute>; - -template class armnn::RefElementwiseWorkload, - armnn::DataType::QuantisedAsymm8, - armnn::DivisionQueueDescriptor, - armnn::StringMapping::RefDivisionWorkload_Execute>; - -template class armnn::RefElementwiseWorkload, - armnn::DataType::Float32, - armnn::MaximumQueueDescriptor, - armnn::StringMapping::RefMaximumWorkload_Execute>; + armnn::DivisionQueueDescriptor, + armnn::StringMapping::RefDivisionWorkload_Execute>; template class armnn::RefElementwiseWorkload, - armnn::DataType::QuantisedAsymm8, - armnn::MaximumQueueDescriptor, - armnn::StringMapping::RefMaximumWorkload_Execute>; - - -template class armnn::RefElementwiseWorkload, - armnn::DataType::Float32, - armnn::MinimumQueueDescriptor, - armnn::StringMapping::RefMinimumWorkload_Execute>; + armnn::MaximumQueueDescriptor, + armnn::StringMapping::RefMaximumWorkload_Execute>; template class armnn::RefElementwiseWorkload, - armnn::DataType::QuantisedAsymm8, - armnn::MinimumQueueDescriptor, - armnn::StringMapping::RefMinimumWorkload_Execute>; + armnn::MinimumQueueDescriptor, + armnn::StringMapping::RefMinimumWorkload_Execute>; \ No newline at end of file -- cgit v1.2.1