diff options
author | Sadik Armagan <sadik.armagan@arm.com> | 2019-04-03 17:48:18 +0100 |
---|---|---|
committer | Sadik Armagan <sadik.armagan@arm.com> | 2019-04-08 15:48:28 +0000 |
commit | 2e6dc3a1c5d47825535db7993ba77eb1596ae99b (patch) | |
tree | 48e73fa1862d17534804d1699bedb76120e88c9f /src/backends/reference/workloads/RefComparisonWorkload.cpp | |
parent | 0324f48e64edb99a5c8d819394545d97e0c2ae97 (diff) | |
download | armnn-2e6dc3a1c5d47825535db7993ba77eb1596ae99b.tar.gz |
IVGCVSW-2861 Refactor the Reference Elementwise workload
* Refactor Reference Comparison workload
* Removed templating based on the DataType
* Implemented BaseIterator to do decode/encode
Change-Id: I18f299f47ee23772f90152c1146b42f07465e105
Signed-off-by: Sadik Armagan <sadik.armagan@arm.com>
Signed-off-by: Kevin May <kevin.may@arm.com>
Diffstat (limited to 'src/backends/reference/workloads/RefComparisonWorkload.cpp')
-rw-r--r-- | src/backends/reference/workloads/RefComparisonWorkload.cpp | 91 |
1 file changed, 51 insertions, 40 deletions
diff --git a/src/backends/reference/workloads/RefComparisonWorkload.cpp b/src/backends/reference/workloads/RefComparisonWorkload.cpp index fe517ff51a..bb8bb04ad3 100644 --- a/src/backends/reference/workloads/RefComparisonWorkload.cpp +++ b/src/backends/reference/workloads/RefComparisonWorkload.cpp @@ -11,55 +11,66 @@ namespace armnn { -template<typename ParentDescriptor, typename Functor> -void RefFloat32ComparisonWorkload<ParentDescriptor, Functor>::ExecuteImpl(const char* debugString) const +template <typename Functor, typename ParentDescriptor, typename armnn::StringMapping::Id DebugString> +void RefComparisonWorkload<Functor, ParentDescriptor, DebugString>::Execute() const { - ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, debugString); + ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, StringMapping::Instance().Get(DebugString)); + const TensorInfo& inputInfo0 = GetTensorInfo(m_Data.m_Inputs[0]); + const TensorInfo& inputInfo1 = GetTensorInfo(m_Data.m_Inputs[1]); + const TensorInfo& outputInfo = GetTensorInfo(m_Data.m_Outputs[0]); - auto data = BaseFloat32ComparisonWorkload<ParentDescriptor>::GetData(); - const TensorShape& inShape0 = GetTensorInfo(data.m_Inputs[0]).GetShape(); - const TensorShape& inShape1 = GetTensorInfo(data.m_Inputs[1]).GetShape(); - const TensorShape& outputShape = GetTensorInfo(data.m_Outputs[0]).GetShape(); + const TensorShape& inShape0 = inputInfo0.GetShape(); + const TensorShape& inShape1 = inputInfo1.GetShape(); + const TensorShape& outShape = outputInfo.GetShape(); - const float* inData0 = GetInputTensorDataFloat(0, data); - const float* inData1 = GetInputTensorDataFloat(1, data); - uint8_t* outData = GetOutputTensorData<uint8_t>(0, data); + switch(inputInfo0.GetDataType()) + { + case armnn::DataType::QuantisedAsymm8: + { + QASymm8Decoder decodeIterator0(GetInputTensorDataU8(0, m_Data), + inputInfo0.GetQuantizationScale(), + inputInfo0.GetQuantizationOffset()); - ElementwiseFunction<Functor, float, uint8_t>(inShape0, - inShape1, - 
outputShape, - inData0, - inData1, - outData); + QASymm8Decoder decodeIterator1(GetInputTensorDataU8(1, m_Data), + inputInfo1.GetQuantizationScale(), + inputInfo1.GetQuantizationOffset()); -} - -template<typename ParentDescriptor, typename Functor> -void RefUint8ComparisonWorkload<ParentDescriptor, Functor>::ExecuteImpl(const char* debugString) const -{ - ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, debugString); - - auto data = BaseUint8ComparisonWorkload<ParentDescriptor>::GetData(); - const TensorShape& inputInfo0 = GetTensorInfo(data.m_Inputs[0]).GetShape(); - const TensorShape& inputInfo1 = GetTensorInfo(data.m_Inputs[1]).GetShape(); - const TensorShape& outputShape = GetTensorInfo(data.m_Outputs[0]).GetShape(); + BooleanEncoder encodeIterator0(GetOutputTensorDataU8(0, m_Data)); - const uint8_t* inData0 = GetInputTensorData<uint8_t>(0, data); - const uint8_t* inData1 = GetInputTensorData<uint8_t>(1, data); - uint8_t* outData = GetOutputTensorData<uint8_t>(0, data); + ElementwiseFunction<Functor, Decoder, ComparisonEncoder>(inShape0, + inShape1, + outShape, + decodeIterator0, + decodeIterator1, + encodeIterator0); + break; + } + case armnn::DataType::Float32: + { + FloatDecoder decodeIterator0(GetInputTensorDataFloat(0, m_Data)); + FloatDecoder decodeIterator1(GetInputTensorDataFloat(1, m_Data)); + BooleanEncoder encodeIterator0(GetOutputTensorDataU8(0, m_Data)); - ElementwiseFunction<Functor, uint8_t, uint8_t>(inputInfo0, - inputInfo1, - outputShape, - inData0, - inData1, - outData); + ElementwiseFunction<Functor, Decoder, ComparisonEncoder>(inShape0, + inShape1, + outShape, + decodeIterator0, + decodeIterator1, + encodeIterator0); + break; + } + default: + BOOST_ASSERT_MSG(false, "RefComparisonWorkload: Not supported Data Type!"); + break; + } } } -template class armnn::RefFloat32ComparisonWorkload<armnn::EqualQueueDescriptor, std::equal_to<float>>; -template class armnn::RefUint8ComparisonWorkload<armnn::EqualQueueDescriptor, std::equal_to<uint8_t>>; 
+template class armnn::RefComparisonWorkload<std::equal_to<float>, + armnn::EqualQueueDescriptor, + armnn::StringMapping::RefEqualWorkload_Execute>; -template class armnn::RefFloat32ComparisonWorkload<armnn::GreaterQueueDescriptor, std::greater<float>>; -template class armnn::RefUint8ComparisonWorkload<armnn::GreaterQueueDescriptor, std::greater<uint8_t>>; +template class armnn::RefComparisonWorkload<std::greater<float>, + armnn::GreaterQueueDescriptor, + armnn::StringMapping::RefGreaterWorkload_Execute>; |