author     Sadik Armagan <sadik.armagan@arm.com>    2019-03-25 09:03:35 +0000
committer  Sadik Armagan <sadik.armagan@arm.com>    2019-03-25 09:03:35 +0000
commit     ef38d5d0f071c53883e2b2f13c85bfb3df34bf88 (patch)
tree       d9f828f79e3819041a7834b8d7ca5b56d1fd3611 /src/backends/reference/workloads/RefElementwiseWorkload.cpp
parent     6e9482013f41725ccca0767c0c5db9b29f77d981 (diff)
download   armnn-ef38d5d0f071c53883e2b2f13c85bfb3df34bf88.tar.gz
IVGCVSW-2861 Refactor the Reference Elementwise workloads
* Refactored the reference Elementwise workloads into a single templated workload.
* The Execute() function now dispatches on the workload's DataType (see the sketch below).
Change-Id: I6d4d6a74cec150ed8cb252e70b629ed968e7093d
Signed-off-by: Sadik Armagan <sadik.armagan@arm.com>
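
To make the shape of the refactor concrete, here is a minimal, self-contained C++ sketch of the pattern the diff below introduces. All names in it (DataType, ElementwiseWorkloadSketch) are illustrative stand-ins, not the Arm NN classes: a single class template carries both the functor and the data type, and Execute() switches on that data type, so one workload class per operation replaces the separate Float32/Uint8 pairs.

// Minimal sketch of the dispatch pattern (hypothetical names, not Arm NN).
#include <cstddef>
#include <functional>
#include <iostream>
#include <vector>

enum class DataType { Float32, QuantisedAsymm8 };

template <typename Functor, DataType DType>
struct ElementwiseWorkloadSketch
{
    void Execute(const std::vector<float>& in0,
                 const std::vector<float>& in1,
                 std::vector<float>& out) const
    {
        switch (DType)    // DType is fixed at compile time per instantiation
        {
            case DataType::Float32:
                // Float path: apply the functor element by element.
                for (std::size_t i = 0; i < out.size(); ++i)
                {
                    out[i] = Functor{}(in0[i], in1[i]);
                }
                break;
            case DataType::QuantisedAsymm8:
                // Quantised path would dequantize to float, apply the
                // functor, then requantize; elided in this sketch.
                break;
        }
    }
};

int main()
{
    std::vector<float> a{1.0f, 2.0f};
    std::vector<float> b{3.0f, 4.0f};
    std::vector<float> out(2);

    // One instantiation per (functor, data type) pair, mirroring the
    // explicit instantiations at the bottom of the diff below.
    ElementwiseWorkloadSketch<std::plus<float>, DataType::Float32> addition;
    addition.Execute(a, b, out);

    std::cout << out[0] << " " << out[1] << std::endl;   // prints: 4 6
}

Because the data type is a template parameter, each explicit instantiation compiles with its switch branch known at compile time, and the compiler can discard the untaken branch.
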
Diffstat (limited to 'src/backends/reference/workloads/RefElementwiseWorkload.cpp')
-rw-r--r--  src/backends/reference/workloads/RefElementwiseWorkload.cpp  161
1 file changed, 105 insertions(+), 56 deletions(-)
diff --git a/src/backends/reference/workloads/RefElementwiseWorkload.cpp b/src/backends/reference/workloads/RefElementwiseWorkload.cpp
index c9b93c8524..356d7a0c16 100644
--- a/src/backends/reference/workloads/RefElementwiseWorkload.cpp
+++ b/src/backends/reference/workloads/RefElementwiseWorkload.cpp
@@ -7,69 +7,118 @@
 #include "ElementwiseFunction.hpp"
 #include "RefWorkloadUtils.hpp"
 #include "Profiling.hpp"
+#include "StringMapping.hpp"
+#include "TypeUtils.hpp"
 #include <vector>
 
 namespace armnn
 {
 
-template <typename ParentDescriptor, typename Functor>
-void BaseFloat32ElementwiseWorkload<ParentDescriptor, Functor>::ExecuteImpl(const char * debugString) const
+template <typename Functor,
+          typename armnn::DataType DataType,
+          typename ParentDescriptor,
+          typename armnn::StringMapping::Id DebugString>
+void RefElementwiseWorkload<Functor, DataType, ParentDescriptor, DebugString>::Execute() const
 {
-    ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, debugString);
-
-    auto data = Float32Workload<ParentDescriptor>::GetData();
-    const TensorShape& inShape0 = GetTensorInfo(data.m_Inputs[0]).GetShape();
-    const TensorShape& inShape1 = GetTensorInfo(data.m_Inputs[1]).GetShape();
-    const TensorShape& outShape = GetTensorInfo(data.m_Outputs[0]).GetShape();
-
-    const float* inData0 = GetInputTensorDataFloat(0, data);
-    const float* inData1 = GetInputTensorDataFloat(1, data);
-    float* outData = GetOutputTensorDataFloat(0, data);
-
-    ElementwiseFunction<Functor, float, float>(inShape0, inShape1, outShape, inData0, inData1, outData);
-}
-
-template <typename ParentDescriptor, typename Functor>
-void BaseUint8ElementwiseWorkload<ParentDescriptor, Functor>::ExecuteImpl(const char * debugString) const
-{
-    ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, debugString);
-
-    auto data = Uint8Workload<ParentDescriptor>::GetData();
-    const TensorInfo& inputInfo0 = GetTensorInfo(data.m_Inputs[0]);
-    const TensorInfo& inputInfo1 = GetTensorInfo(data.m_Inputs[1]);
-    const TensorInfo& outputInfo = GetTensorInfo(data.m_Outputs[0]);
-
-    auto dequant0 = Dequantize(GetInputTensorDataU8(0, data), inputInfo0);
-    auto dequant1 = Dequantize(GetInputTensorDataU8(1, data), inputInfo1);
-
-    std::vector<float> results(outputInfo.GetNumElements());
-
-    ElementwiseFunction<Functor, float, float>(inputInfo0.GetShape(),
-                                               inputInfo1.GetShape(),
-                                               outputInfo.GetShape(),
-                                               dequant0.data(),
-                                               dequant1.data(),
-                                               results.data());
-
-    Quantize(GetOutputTensorDataU8(0, data), results.data(), outputInfo);
+    ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, StringMapping::Instance().Get(DebugString));
+
+    const TensorInfo& inputInfo0 = GetTensorInfo(m_Data.m_Inputs[0]);
+    const TensorInfo& inputInfo1 = GetTensorInfo(m_Data.m_Inputs[1]);
+    const TensorInfo& outputInfo = GetTensorInfo(m_Data.m_Outputs[0]);
+
+    const TensorShape& inShape0 = inputInfo0.GetShape();
+    const TensorShape& inShape1 = inputInfo1.GetShape();
+    const TensorShape& outShape = outputInfo.GetShape();
+
+    switch(DataType)
+    {
+        case armnn::DataType::QuantisedAsymm8:
+        {
+            std::vector<float> results(outputInfo.GetNumElements());
+            ElementwiseFunction<Functor, float, float>(inShape0,
+                                                       inShape1,
+                                                       outShape,
+                                                       Dequantize(GetInputTensorDataU8(0, m_Data), inputInfo0).data(),
+                                                       Dequantize(GetInputTensorDataU8(1, m_Data), inputInfo1).data(),
+                                                       results.data());
+            Quantize(GetOutputTensorDataU8(0, m_Data), results.data(), outputInfo);
+            break;
+        }
+        case armnn::DataType::Float32:
+        {
+            ElementwiseFunction<Functor, float, float>(inShape0,
+                                                       inShape1,
+                                                       outShape,
+                                                       GetInputTensorDataFloat(0, m_Data),
+                                                       GetInputTensorDataFloat(1, m_Data),
+                                                       GetOutputTensorDataFloat(0, m_Data));
+            break;
+        }
+        default:
+            BOOST_ASSERT_MSG(false, "Unknown Data Type!");
+            break;
+    }
 }
 
 }
 
-template class armnn::BaseFloat32ElementwiseWorkload<armnn::AdditionQueueDescriptor, std::plus<float>>;
-template class armnn::BaseUint8ElementwiseWorkload<armnn::AdditionQueueDescriptor, std::plus<float>>;
-
-template class armnn::BaseFloat32ElementwiseWorkload<armnn::SubtractionQueueDescriptor, std::minus<float>>;
-template class armnn::BaseUint8ElementwiseWorkload<armnn::SubtractionQueueDescriptor, std::minus<float>>;
-
-template class armnn::BaseFloat32ElementwiseWorkload<armnn::MultiplicationQueueDescriptor, std::multiplies<float>>;
-template class armnn::BaseUint8ElementwiseWorkload<armnn::MultiplicationQueueDescriptor, std::multiplies<float>>;
-
-template class armnn::BaseFloat32ElementwiseWorkload<armnn::DivisionQueueDescriptor, std::divides<float>>;
-template class armnn::BaseUint8ElementwiseWorkload<armnn::DivisionQueueDescriptor, std::divides<float>>;
-
-template class armnn::BaseFloat32ElementwiseWorkload<armnn::MaximumQueueDescriptor, armnn::maximum<float>>;
-template class armnn::BaseUint8ElementwiseWorkload<armnn::MaximumQueueDescriptor, armnn::maximum<float>>;
-
-template class armnn::BaseFloat32ElementwiseWorkload<armnn::MinimumQueueDescriptor, armnn::minimum<float>>;
-template class armnn::BaseUint8ElementwiseWorkload<armnn::MinimumQueueDescriptor, armnn::minimum<float>>;
+template class armnn::RefElementwiseWorkload<std::plus<float>,
+                                             armnn::DataType::Float32,
+                                             armnn::AdditionQueueDescriptor,
+                                             armnn::StringMapping::RefAdditionWorkload_Execute>;
+
+template class armnn::RefElementwiseWorkload<std::plus<float>,
+                                             armnn::DataType::QuantisedAsymm8,
+                                             armnn::AdditionQueueDescriptor,
+                                             armnn::StringMapping::RefAdditionWorkload_Execute>;
+
+template class armnn::RefElementwiseWorkload<std::minus<float>,
+                                             armnn::DataType::Float32,
+                                             armnn::SubtractionQueueDescriptor,
+                                             armnn::StringMapping::RefSubtractionWorkload_Execute>;
+
+template class armnn::RefElementwiseWorkload<std::minus<float>,
+                                             armnn::DataType::QuantisedAsymm8,
+                                             armnn::SubtractionQueueDescriptor,
+                                             armnn::StringMapping::RefSubtractionWorkload_Execute>;
+
+template class armnn::RefElementwiseWorkload<std::multiplies<float>,
+                                             armnn::DataType::Float32,
+                                             armnn::MultiplicationQueueDescriptor,
+                                             armnn::StringMapping::RefMultiplicationWorkload_Execute>;
+
+template class armnn::RefElementwiseWorkload<std::multiplies<float>,
+                                             armnn::DataType::QuantisedAsymm8,
+                                             armnn::MultiplicationQueueDescriptor,
+                                             armnn::StringMapping::RefMultiplicationWorkload_Execute>;
+
+template class armnn::RefElementwiseWorkload<std::divides<float>,
+                                             armnn::DataType::Float32,
+                                             armnn::DivisionQueueDescriptor,
+                                             armnn::StringMapping::RefDivisionWorkload_Execute>;
+
+template class armnn::RefElementwiseWorkload<std::divides<float>,
+                                             armnn::DataType::QuantisedAsymm8,
+                                             armnn::DivisionQueueDescriptor,
+                                             armnn::StringMapping::RefDivisionWorkload_Execute>;
+
+template class armnn::RefElementwiseWorkload<armnn::maximum<float>,
+                                             armnn::DataType::Float32,
+                                             armnn::MaximumQueueDescriptor,
+                                             armnn::StringMapping::RefMaximumWorkload_Execute>;
+
+template class armnn::RefElementwiseWorkload<armnn::maximum<float>,
+                                             armnn::DataType::QuantisedAsymm8,
+                                             armnn::MaximumQueueDescriptor,
+                                             armnn::StringMapping::RefMaximumWorkload_Execute>;
+
+
+template class armnn::RefElementwiseWorkload<armnn::minimum<float>,
+                                             armnn::DataType::Float32,
+                                             armnn::MinimumQueueDescriptor,
+                                             armnn::StringMapping::RefMinimumWorkload_Execute>;
+
+template class armnn::RefElementwiseWorkload<armnn::minimum<float>,
+                                             armnn::DataType::QuantisedAsymm8,
+                                             armnn::MinimumQueueDescriptor,
+                                             armnn::StringMapping::RefMinimumWorkload_Execute>;