diff options
author | Sadik Armagan <sadik.armagan@arm.com> | 2019-04-03 17:48:18 +0100 |
---|---|---|
committer | Sadik Armagan <sadik.armagan@arm.com> | 2019-04-08 15:48:28 +0000 |
commit | 2e6dc3a1c5d47825535db7993ba77eb1596ae99b (patch) | |
tree | 48e73fa1862d17534804d1699bedb76120e88c9f /src/backends/reference/workloads/RefElementwiseWorkload.cpp | |
parent | 0324f48e64edb99a5c8d819394545d97e0c2ae97 (diff) | |
download | armnn-2e6dc3a1c5d47825535db7993ba77eb1596ae99b.tar.gz |
IVGCVSW-2861 Refactor the Reference Elementwise workload
* Refactor Reference Comparison workload
* Removed templating based on the DataType
* Implemented BaseIterator to do decode/encode
Change-Id: I18f299f47ee23772f90152c1146b42f07465e105
Signed-off-by: Sadik Armagan <sadik.armagan@arm.com>
Signed-off-by: Kevin May <kevin.may@arm.com>
Diffstat (limited to 'src/backends/reference/workloads/RefElementwiseWorkload.cpp')
-rw-r--r-- | src/backends/reference/workloads/RefElementwiseWorkload.cpp | 115 |
1 files changed, 44 insertions, 71 deletions
diff --git a/src/backends/reference/workloads/RefElementwiseWorkload.cpp b/src/backends/reference/workloads/RefElementwiseWorkload.cpp index 356d7a0c16..6e6e1d5f21 100644 --- a/src/backends/reference/workloads/RefElementwiseWorkload.cpp +++ b/src/backends/reference/workloads/RefElementwiseWorkload.cpp @@ -14,14 +14,10 @@ namespace armnn { -template <typename Functor, - typename armnn::DataType DataType, - typename ParentDescriptor, - typename armnn::StringMapping::Id DebugString> -void RefElementwiseWorkload<Functor, DataType, ParentDescriptor, DebugString>::Execute() const +template <typename Functor, typename ParentDescriptor, typename armnn::StringMapping::Id DebugString> +void RefElementwiseWorkload<Functor, ParentDescriptor, DebugString>::Execute() const { ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, StringMapping::Instance().Get(DebugString)); - const TensorInfo& inputInfo0 = GetTensorInfo(m_Data.m_Inputs[0]); const TensorInfo& inputInfo1 = GetTensorInfo(m_Data.m_Inputs[1]); const TensorInfo& outputInfo = GetTensorInfo(m_Data.m_Outputs[0]); @@ -30,32 +26,46 @@ void RefElementwiseWorkload<Functor, DataType, ParentDescriptor, DebugString>::E const TensorShape& inShape1 = inputInfo1.GetShape(); const TensorShape& outShape = outputInfo.GetShape(); - switch(DataType) + switch(inputInfo0.GetDataType()) { case armnn::DataType::QuantisedAsymm8: { - std::vector<float> results(outputInfo.GetNumElements()); - ElementwiseFunction<Functor, float, float>(inShape0, - inShape1, - outShape, - Dequantize(GetInputTensorDataU8(0, m_Data), inputInfo0).data(), - Dequantize(GetInputTensorDataU8(1, m_Data), inputInfo1).data(), - results.data()); - Quantize(GetOutputTensorDataU8(0, m_Data), results.data(), outputInfo); + QASymm8Decoder decodeIterator0(GetInputTensorDataU8(0, m_Data), + inputInfo0.GetQuantizationScale(), + inputInfo0.GetQuantizationOffset()); + + QASymm8Decoder decodeIterator1(GetInputTensorDataU8(1, m_Data), + inputInfo1.GetQuantizationScale(), + inputInfo1.GetQuantizationOffset()); + + QASymm8Encoder encodeIterator0(GetOutputTensorDataU8(0, m_Data), + outputInfo.GetQuantizationScale(), + outputInfo.GetQuantizationOffset()); + + ElementwiseFunction<Functor, Decoder, Encoder>(inShape0, + inShape1, + outShape, + decodeIterator0, + decodeIterator1, + encodeIterator0); break; } case armnn::DataType::Float32: { - ElementwiseFunction<Functor, float, float>(inShape0, - inShape1, - outShape, - GetInputTensorDataFloat(0, m_Data), - GetInputTensorDataFloat(1, m_Data), - GetOutputTensorDataFloat(0, m_Data)); + FloatDecoder decodeIterator0(GetInputTensorDataFloat(0, m_Data)); + FloatDecoder decodeIterator1(GetInputTensorDataFloat(1, m_Data)); + FloatEncoder encodeIterator0(GetOutputTensorDataFloat(0, m_Data)); + + ElementwiseFunction<Functor, Decoder, Encoder>(inShape0, + inShape1, + outShape, + decodeIterator0, + decodeIterator1, + encodeIterator0); break; } default: - BOOST_ASSERT_MSG(false, "Unknown Data Type!"); + BOOST_ASSERT_MSG(false, "RefElementwiseWorkload: Not supported Data Type!"); break; } } @@ -63,62 +73,25 @@ void RefElementwiseWorkload<Functor, DataType, ParentDescriptor, DebugString>::E } template class armnn::RefElementwiseWorkload<std::plus<float>, - armnn::DataType::Float32, - armnn::AdditionQueueDescriptor, - armnn::StringMapping::RefAdditionWorkload_Execute>; - -template class armnn::RefElementwiseWorkload<std::plus<float>, - armnn::DataType::QuantisedAsymm8, - armnn::AdditionQueueDescriptor, - armnn::StringMapping::RefAdditionWorkload_Execute>; + armnn::AdditionQueueDescriptor, + armnn::StringMapping::RefAdditionWorkload_Execute>; template class armnn::RefElementwiseWorkload<std::minus<float>, - armnn::DataType::Float32, - armnn::SubtractionQueueDescriptor, - armnn::StringMapping::RefSubtractionWorkload_Execute>; - -template class armnn::RefElementwiseWorkload<std::minus<float>, - armnn::DataType::QuantisedAsymm8, - armnn::SubtractionQueueDescriptor, - armnn::StringMapping::RefSubtractionWorkload_Execute>; - -template class armnn::RefElementwiseWorkload<std::multiplies<float>, - armnn::DataType::Float32, - armnn::MultiplicationQueueDescriptor, - armnn::StringMapping::RefMultiplicationWorkload_Execute>; + armnn::SubtractionQueueDescriptor, + armnn::StringMapping::RefSubtractionWorkload_Execute>; template class armnn::RefElementwiseWorkload<std::multiplies<float>, - armnn::DataType::QuantisedAsymm8, - armnn::MultiplicationQueueDescriptor, - armnn::StringMapping::RefMultiplicationWorkload_Execute>; + armnn::MultiplicationQueueDescriptor, + armnn::StringMapping::RefMultiplicationWorkload_Execute>; template class armnn::RefElementwiseWorkload<std::divides<float>, - armnn::DataType::Float32, - armnn::DivisionQueueDescriptor, - armnn::StringMapping::RefDivisionWorkload_Execute>; - -template class armnn::RefElementwiseWorkload<std::divides<float>, - armnn::DataType::QuantisedAsymm8, - armnn::DivisionQueueDescriptor, - armnn::StringMapping::RefDivisionWorkload_Execute>; - -template class armnn::RefElementwiseWorkload<armnn::maximum<float>, - armnn::DataType::Float32, - armnn::MaximumQueueDescriptor, - armnn::StringMapping::RefMaximumWorkload_Execute>; + armnn::DivisionQueueDescriptor, + armnn::StringMapping::RefDivisionWorkload_Execute>; template class armnn::RefElementwiseWorkload<armnn::maximum<float>, - armnn::DataType::QuantisedAsymm8, - armnn::MaximumQueueDescriptor, - armnn::StringMapping::RefMaximumWorkload_Execute>; - - -template class armnn::RefElementwiseWorkload<armnn::minimum<float>, - armnn::DataType::Float32, - armnn::MinimumQueueDescriptor, - armnn::StringMapping::RefMinimumWorkload_Execute>; + armnn::MaximumQueueDescriptor, + armnn::StringMapping::RefMaximumWorkload_Execute>; template class armnn::RefElementwiseWorkload<armnn::minimum<float>, - armnn::DataType::QuantisedAsymm8, - armnn::MinimumQueueDescriptor, - armnn::StringMapping::RefMinimumWorkload_Execute>; + armnn::MinimumQueueDescriptor, + armnn::StringMapping::RefMinimumWorkload_Execute>;
\ No newline at end of file |