From 2e6dc3a1c5d47825535db7993ba77eb1596ae99b Mon Sep 17 00:00:00 2001 From: Sadik Armagan Date: Wed, 3 Apr 2019 17:48:18 +0100 Subject: IVGCVSW-2861 Refactor the Reference Elementwise workload * Refactor Reference Comparison workload * Removed templating based on the DataType * Implemented BaseIterator to do decode/encode Change-Id: I18f299f47ee23772f90152c1146b42f07465e105 Signed-off-by: Sadik Armagan Signed-off-by: Kevin May --- src/backends/reference/RefWorkloadFactory.cpp | 16 +-- .../reference/test/RefCreateWorkloadTests.cpp | 16 +-- src/backends/reference/workloads/BaseIterator.hpp | 155 +++++++++++++++++++++ src/backends/reference/workloads/Broadcast.hpp | 24 +++- src/backends/reference/workloads/CMakeLists.txt | 1 + .../reference/workloads/ElementwiseFunction.cpp | 34 ++--- .../reference/workloads/ElementwiseFunction.hpp | 9 +- .../reference/workloads/RefComparisonWorkload.cpp | 91 ++++++------ .../reference/workloads/RefComparisonWorkload.hpp | 76 ++-------- .../reference/workloads/RefElementwiseWorkload.cpp | 115 ++++++--------- .../reference/workloads/RefElementwiseWorkload.hpp | 69 ++------- 11 files changed, 329 insertions(+), 277 deletions(-) create mode 100644 src/backends/reference/workloads/BaseIterator.hpp (limited to 'src/backends/reference') diff --git a/src/backends/reference/RefWorkloadFactory.cpp b/src/backends/reference/RefWorkloadFactory.cpp index 619c14e007..8ea923d599 100644 --- a/src/backends/reference/RefWorkloadFactory.cpp +++ b/src/backends/reference/RefWorkloadFactory.cpp @@ -174,13 +174,13 @@ std::unique_ptr RefWorkloadFactory::CreateNormalization( std::unique_ptr RefWorkloadFactory::CreateAddition(const AdditionQueueDescriptor& descriptor, const WorkloadInfo& info) const { - return MakeWorkload(descriptor, info); + return std::make_unique(descriptor, info); } std::unique_ptr RefWorkloadFactory::CreateMultiplication( const MultiplicationQueueDescriptor& descriptor, const WorkloadInfo& info) const { - return MakeWorkload(descriptor, info); + return std::make_unique(descriptor, info); } std::unique_ptr RefWorkloadFactory::CreateBatchNormalization( @@ -266,19 +266,19 @@ std::unique_ptr RefWorkloadFactory::CreateConvertFp32ToFp16( std::unique_ptr RefWorkloadFactory::CreateDivision( const DivisionQueueDescriptor& descriptor, const WorkloadInfo& info) const { - return MakeWorkload(descriptor, info); + return std::make_unique(descriptor, info); } std::unique_ptr RefWorkloadFactory::CreateSubtraction( const SubtractionQueueDescriptor& descriptor, const WorkloadInfo& info) const { - return MakeWorkload(descriptor, info); + return std::make_unique(descriptor, info); } std::unique_ptr RefWorkloadFactory::CreateMaximum( const MaximumQueueDescriptor& descriptor, const WorkloadInfo& info) const { - return MakeWorkload(descriptor, info); + return std::make_unique(descriptor, info); } std::unique_ptr RefWorkloadFactory::CreateMean( @@ -290,7 +290,7 @@ std::unique_ptr RefWorkloadFactory::CreateMean( std::unique_ptr RefWorkloadFactory::CreateMinimum( const MinimumQueueDescriptor& descriptor, const WorkloadInfo& info) const { - return MakeWorkload(descriptor, info); + return std::make_unique(descriptor, info); } std::unique_ptr RefWorkloadFactory::CreatePad(const PadQueueDescriptor& descriptor, @@ -302,7 +302,7 @@ std::unique_ptr RefWorkloadFactory::CreatePad(const PadQueueDescripto std::unique_ptr RefWorkloadFactory::CreateEqual(const EqualQueueDescriptor& descriptor, const WorkloadInfo& info) const { - return MakeWorkload(descriptor, info); + return 
std::make_unique(descriptor, info); } std::unique_ptr RefWorkloadFactory::CreateBatchToSpaceNd(const BatchToSpaceNdQueueDescriptor& descriptor, @@ -320,7 +320,7 @@ std::unique_ptr RefWorkloadFactory::CreateStridedSlice(const StridedS std::unique_ptr RefWorkloadFactory::CreateGreater(const GreaterQueueDescriptor& descriptor, const WorkloadInfo& info) const { - return MakeWorkload(descriptor, info); + return std::make_unique(descriptor, info); } std::unique_ptr RefWorkloadFactory::CreateDebug(const DebugQueueDescriptor& descriptor, diff --git a/src/backends/reference/test/RefCreateWorkloadTests.cpp b/src/backends/reference/test/RefCreateWorkloadTests.cpp index 8621122925..09b0246895 100644 --- a/src/backends/reference/test/RefCreateWorkloadTests.cpp +++ b/src/backends/reference/test/RefCreateWorkloadTests.cpp @@ -82,7 +82,7 @@ static void RefCreateElementwiseWorkloadTest() BOOST_AUTO_TEST_CASE(CreateAdditionFloatWorkload) { - RefCreateElementwiseWorkloadTest(); @@ -90,7 +90,7 @@ BOOST_AUTO_TEST_CASE(CreateAdditionFloatWorkload) BOOST_AUTO_TEST_CASE(CreateAdditionUint8Workload) { - RefCreateElementwiseWorkloadTest(); @@ -98,7 +98,7 @@ BOOST_AUTO_TEST_CASE(CreateAdditionUint8Workload) BOOST_AUTO_TEST_CASE(CreateSubtractionFloatWorkload) { - RefCreateElementwiseWorkloadTest(); @@ -106,7 +106,7 @@ BOOST_AUTO_TEST_CASE(CreateSubtractionFloatWorkload) BOOST_AUTO_TEST_CASE(CreateSubtractionUint8Workload) { - RefCreateElementwiseWorkloadTest(); @@ -114,7 +114,7 @@ BOOST_AUTO_TEST_CASE(CreateSubtractionUint8Workload) BOOST_AUTO_TEST_CASE(CreateMultiplicationFloatWorkload) { - RefCreateElementwiseWorkloadTest(); @@ -122,7 +122,7 @@ BOOST_AUTO_TEST_CASE(CreateMultiplicationFloatWorkload) BOOST_AUTO_TEST_CASE(CreateMultiplicationUint8Workload) { - RefCreateElementwiseWorkloadTest(); @@ -130,7 +130,7 @@ BOOST_AUTO_TEST_CASE(CreateMultiplicationUint8Workload) BOOST_AUTO_TEST_CASE(CreateDivisionFloatWorkload) { - RefCreateElementwiseWorkloadTest(); @@ -138,7 +138,7 @@ BOOST_AUTO_TEST_CASE(CreateDivisionFloatWorkload) BOOST_AUTO_TEST_CASE(CreateDivisionUint8Workload) { - RefCreateElementwiseWorkloadTest(); diff --git a/src/backends/reference/workloads/BaseIterator.hpp b/src/backends/reference/workloads/BaseIterator.hpp new file mode 100644 index 0000000000..cfa8ce7e91 --- /dev/null +++ b/src/backends/reference/workloads/BaseIterator.hpp @@ -0,0 +1,155 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. 
+// SPDX-License-Identifier: MIT +// + +#pragma once + +#include +#include + +namespace armnn +{ + +class BaseIterator +{ +public: + BaseIterator() {} + + virtual ~BaseIterator() {} + + virtual BaseIterator& operator++() = 0; + + virtual BaseIterator& operator+=(const unsigned int increment) = 0; + + virtual BaseIterator& operator-=(const unsigned int increment) = 0; +}; + +class Decoder : public BaseIterator +{ +public: + Decoder() : BaseIterator() {} + + virtual ~Decoder() {} + + virtual float Get() const = 0; +}; + +class Encoder : public BaseIterator +{ +public: + Encoder() : BaseIterator() {} + + virtual ~Encoder() {} + + virtual void Set(const float& right) = 0; +}; + +class ComparisonEncoder : public BaseIterator +{ +public: + ComparisonEncoder() : BaseIterator() {} + + virtual ~ComparisonEncoder() {} + + virtual void Set(bool right) = 0; +}; + +template +class TypedIterator : public Base +{ +public: + TypedIterator(T* data) + : m_Iterator(data) + {} + + TypedIterator& operator++() override + { + ++m_Iterator; + return *this; + } + + TypedIterator& operator+=(const unsigned int increment) override + { + m_Iterator += increment; + return *this; + } + + TypedIterator& operator-=(const unsigned int increment) override + { + m_Iterator -= increment; + return *this; + } + + T* m_Iterator; +}; + +class QASymm8Decoder : public TypedIterator +{ +public: + QASymm8Decoder(const uint8_t* data, const float scale, const int32_t offset) + : TypedIterator(data), m_Scale(scale), m_Offset(offset) {} + + float Get() const override + { + return armnn::Dequantize(*m_Iterator, m_Scale, m_Offset); + } + +private: + const float m_Scale; + const int32_t m_Offset; +}; + +class FloatDecoder : public TypedIterator +{ +public: + FloatDecoder(const float* data) + : TypedIterator(data) {} + + float Get() const override + { + return *m_Iterator; + } +}; + +class FloatEncoder : public TypedIterator +{ +public: + FloatEncoder(float* data) + : TypedIterator(data) {} + + void Set(const float& right) override + { + *m_Iterator = right; + } +}; + +class QASymm8Encoder : public TypedIterator +{ +public: + QASymm8Encoder(uint8_t* data, const float scale, const int32_t offset) + : TypedIterator(data), m_Scale(scale), m_Offset(offset) {} + + void Set(const float& right) override + { + *m_Iterator = armnn::Quantize(right, m_Scale, m_Offset); + } + +private: + const float m_Scale; + const int32_t m_Offset; +}; + +class BooleanEncoder : public TypedIterator +{ +public: + BooleanEncoder(uint8_t* data) + : TypedIterator(data) {} + + void Set(bool right) override + { + *m_Iterator = right; + } +}; + +} //namespace armnn \ No newline at end of file diff --git a/src/backends/reference/workloads/Broadcast.hpp b/src/backends/reference/workloads/Broadcast.hpp index e92ed0598d..5bf6be8939 100644 --- a/src/backends/reference/workloads/Broadcast.hpp +++ b/src/backends/reference/workloads/Broadcast.hpp @@ -3,6 +3,7 @@ // SPDX-License-Identifier: MIT // +#include "BaseIterator.hpp" #include #include @@ -19,19 +20,23 @@ struct BroadcastLoop return static_cast(m_DimData.size()); } - template + template void Unroll(Func operationFunc, unsigned int dimension, - const T0* inData0, - const T1* inData1, - U* outData) + DecoderOp& inData0, + DecoderOp& inData1, + EncoderOp& outData) { if (dimension >= GetNumDimensions()) { - *outData = operationFunc(*inData0, *inData1); + outData.Set(operationFunc(inData0.Get(), inData1.Get())); return; } + unsigned int inData0Movement = 0; + unsigned int inData1Movement = 0; + unsigned int outDataMovement = 0; + 
for (unsigned int i = 0; i < m_DimData[dimension].m_DimSize; i++) { Unroll(operationFunc, dimension + 1, inData0, inData1, outData); @@ -39,7 +44,16 @@ struct BroadcastLoop inData0 += m_DimData[dimension].m_Stride1; inData1 += m_DimData[dimension].m_Stride2; outData += m_DimData[dimension].m_StrideOut; + + inData0Movement += m_DimData[dimension].m_Stride1; + inData1Movement += m_DimData[dimension].m_Stride2; + outDataMovement += m_DimData[dimension].m_StrideOut; } + + // move iterator back to the start + inData0 -= inData0Movement; + inData1 -= inData1Movement; + outData -= outDataMovement; } private: diff --git a/src/backends/reference/workloads/CMakeLists.txt b/src/backends/reference/workloads/CMakeLists.txt index 4f5fbb554e..4ff2466e87 100644 --- a/src/backends/reference/workloads/CMakeLists.txt +++ b/src/backends/reference/workloads/CMakeLists.txt @@ -6,6 +6,7 @@ list(APPEND armnnRefBackendWorkloads_sources Activation.cpp Activation.hpp + BaseIterator.hpp BatchNormImpl.hpp BatchToSpaceNd.cpp BatchToSpaceNd.hpp diff --git a/src/backends/reference/workloads/ElementwiseFunction.cpp b/src/backends/reference/workloads/ElementwiseFunction.cpp index c8c25ef9e9..934a86217a 100644 --- a/src/backends/reference/workloads/ElementwiseFunction.cpp +++ b/src/backends/reference/workloads/ElementwiseFunction.cpp @@ -13,26 +13,26 @@ namespace armnn { -template -ElementwiseFunction::ElementwiseFunction(const TensorShape& inShape0, - const TensorShape& inShape1, - const TensorShape& outShape, - const dataTypeInput* inData0, - const dataTypeInput* inData1, - dataTypeOutput* outData) +template +ElementwiseFunction::ElementwiseFunction(const TensorShape& inShape0, + const TensorShape& inShape1, + const TensorShape& outShape, + DecoderOp& inData0, + DecoderOp& inData1, + EncoderOp& outData) { BroadcastLoop(inShape0, inShape1, outShape).Unroll(Functor(), 0, inData0, inData1, outData); } } //namespace armnn -template struct armnn::ElementwiseFunction, float, float>; -template struct armnn::ElementwiseFunction, float, float>; -template struct armnn::ElementwiseFunction, float, float>; -template struct armnn::ElementwiseFunction, float, float>; -template struct armnn::ElementwiseFunction, float, float>; -template struct armnn::ElementwiseFunction, float, float>; -template struct armnn::ElementwiseFunction, float ,uint8_t>; -template struct armnn::ElementwiseFunction, uint8_t, uint8_t>; -template struct armnn::ElementwiseFunction, float, uint8_t>; -template struct armnn::ElementwiseFunction, uint8_t, uint8_t>; +template struct armnn::ElementwiseFunction, armnn::Decoder, armnn::Encoder>; +template struct armnn::ElementwiseFunction, armnn::Decoder, armnn::Encoder>; +template struct armnn::ElementwiseFunction, armnn::Decoder, armnn::Encoder>; +template struct armnn::ElementwiseFunction, armnn::Decoder, armnn::Encoder>; +template struct armnn::ElementwiseFunction, armnn::Decoder, armnn::Encoder>; +template struct armnn::ElementwiseFunction, armnn::Decoder, armnn::Encoder>; + +template struct armnn::ElementwiseFunction, armnn::Decoder, armnn::ComparisonEncoder>; +template struct armnn::ElementwiseFunction, armnn::Decoder, armnn::ComparisonEncoder>; + diff --git a/src/backends/reference/workloads/ElementwiseFunction.hpp b/src/backends/reference/workloads/ElementwiseFunction.hpp index 8099f3279a..9eb003d5f9 100644 --- a/src/backends/reference/workloads/ElementwiseFunction.hpp +++ b/src/backends/reference/workloads/ElementwiseFunction.hpp @@ -5,20 +5,21 @@ #pragma once +#include "BaseIterator.hpp" #include namespace armnn 
{ -template +template struct ElementwiseFunction { ElementwiseFunction(const TensorShape& inShape0, const TensorShape& inShape1, const TensorShape& outShape, - const dataTypeInput* inData0, - const dataTypeInput* inData1, - dataTypeOutput* outData); + DecoderOp& inData0, + DecoderOp& inData1, + EncoderOp& outData); }; } //namespace armnn diff --git a/src/backends/reference/workloads/RefComparisonWorkload.cpp b/src/backends/reference/workloads/RefComparisonWorkload.cpp index fe517ff51a..bb8bb04ad3 100644 --- a/src/backends/reference/workloads/RefComparisonWorkload.cpp +++ b/src/backends/reference/workloads/RefComparisonWorkload.cpp @@ -11,55 +11,66 @@ namespace armnn { -template -void RefFloat32ComparisonWorkload::ExecuteImpl(const char* debugString) const +template +void RefComparisonWorkload::Execute() const { - ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, debugString); + ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, StringMapping::Instance().Get(DebugString)); + const TensorInfo& inputInfo0 = GetTensorInfo(m_Data.m_Inputs[0]); + const TensorInfo& inputInfo1 = GetTensorInfo(m_Data.m_Inputs[1]); + const TensorInfo& outputInfo = GetTensorInfo(m_Data.m_Outputs[0]); - auto data = BaseFloat32ComparisonWorkload::GetData(); - const TensorShape& inShape0 = GetTensorInfo(data.m_Inputs[0]).GetShape(); - const TensorShape& inShape1 = GetTensorInfo(data.m_Inputs[1]).GetShape(); - const TensorShape& outputShape = GetTensorInfo(data.m_Outputs[0]).GetShape(); + const TensorShape& inShape0 = inputInfo0.GetShape(); + const TensorShape& inShape1 = inputInfo1.GetShape(); + const TensorShape& outShape = outputInfo.GetShape(); - const float* inData0 = GetInputTensorDataFloat(0, data); - const float* inData1 = GetInputTensorDataFloat(1, data); - uint8_t* outData = GetOutputTensorData(0, data); + switch(inputInfo0.GetDataType()) + { + case armnn::DataType::QuantisedAsymm8: + { + QASymm8Decoder decodeIterator0(GetInputTensorDataU8(0, m_Data), + inputInfo0.GetQuantizationScale(), + inputInfo0.GetQuantizationOffset()); - ElementwiseFunction(inShape0, - inShape1, - outputShape, - inData0, - inData1, - outData); + QASymm8Decoder decodeIterator1(GetInputTensorDataU8(1, m_Data), + inputInfo1.GetQuantizationScale(), + inputInfo1.GetQuantizationOffset()); -} - -template -void RefUint8ComparisonWorkload::ExecuteImpl(const char* debugString) const -{ - ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, debugString); - - auto data = BaseUint8ComparisonWorkload::GetData(); - const TensorShape& inputInfo0 = GetTensorInfo(data.m_Inputs[0]).GetShape(); - const TensorShape& inputInfo1 = GetTensorInfo(data.m_Inputs[1]).GetShape(); - const TensorShape& outputShape = GetTensorInfo(data.m_Outputs[0]).GetShape(); + BooleanEncoder encodeIterator0(GetOutputTensorDataU8(0, m_Data)); - const uint8_t* inData0 = GetInputTensorData(0, data); - const uint8_t* inData1 = GetInputTensorData(1, data); - uint8_t* outData = GetOutputTensorData(0, data); + ElementwiseFunction(inShape0, + inShape1, + outShape, + decodeIterator0, + decodeIterator1, + encodeIterator0); + break; + } + case armnn::DataType::Float32: + { + FloatDecoder decodeIterator0(GetInputTensorDataFloat(0, m_Data)); + FloatDecoder decodeIterator1(GetInputTensorDataFloat(1, m_Data)); + BooleanEncoder encodeIterator0(GetOutputTensorDataU8(0, m_Data)); - ElementwiseFunction(inputInfo0, - inputInfo1, - outputShape, - inData0, - inData1, - outData); + ElementwiseFunction(inShape0, + inShape1, + outShape, + decodeIterator0, + decodeIterator1, + encodeIterator0); + break; + } + default: + 
BOOST_ASSERT_MSG(false, "RefComparisonWorkload: Not supported Data Type!"); + break; + } } } -template class armnn::RefFloat32ComparisonWorkload>; -template class armnn::RefUint8ComparisonWorkload>; +template class armnn::RefComparisonWorkload, + armnn::EqualQueueDescriptor, + armnn::StringMapping::RefEqualWorkload_Execute>; -template class armnn::RefFloat32ComparisonWorkload>; -template class armnn::RefUint8ComparisonWorkload>; +template class armnn::RefComparisonWorkload, + armnn::GreaterQueueDescriptor, + armnn::StringMapping::RefGreaterWorkload_Execute>; diff --git a/src/backends/reference/workloads/RefComparisonWorkload.hpp b/src/backends/reference/workloads/RefComparisonWorkload.hpp index 524d20625a..cfc2dcf2aa 100644 --- a/src/backends/reference/workloads/RefComparisonWorkload.hpp +++ b/src/backends/reference/workloads/RefComparisonWorkload.hpp @@ -13,80 +13,24 @@ namespace armnn { -template -class RefComparisonWorkload -{ - // Needs specialization. The default is empty on purpose. -}; - -template -class RefFloat32ComparisonWorkload : public BaseFloat32ComparisonWorkload -{ -public: - using BaseFloat32ComparisonWorkload::BaseFloat32ComparisonWorkload; - void ExecuteImpl(const char * debugString) const; -}; - -template -class RefComparisonWorkload - : public RefFloat32ComparisonWorkload -{ -public: - using RefFloat32ComparisonWorkload::RefFloat32ComparisonWorkload; - - virtual void Execute() const override - { - using Parent = RefFloat32ComparisonWorkload; - Parent::ExecuteImpl(StringMapping::Instance().Get(DebugString)); - } -}; - -template -class RefUint8ComparisonWorkload : public BaseUint8ComparisonWorkload -{ -public: - using BaseUint8ComparisonWorkload::BaseUint8ComparisonWorkload; - void ExecuteImpl(const char * debugString) const; -}; - template -class RefComparisonWorkload - : public RefUint8ComparisonWorkload +class RefComparisonWorkload : public BaseWorkload { public: - using RefUint8ComparisonWorkload::RefUint8ComparisonWorkload; + using BaseWorkload::m_Data; + using BaseWorkload::BaseWorkload; - virtual void Execute() const override - { - using Parent = RefUint8ComparisonWorkload; - Parent::ExecuteImpl(StringMapping::Instance().Get(DebugString)); - } + void Execute() const override; }; -using RefEqualFloat32Workload = +using RefEqualWorkload = RefComparisonWorkload, - DataType::Float32, - EqualQueueDescriptor, - StringMapping::RefEqualWorkload_Execute>; + EqualQueueDescriptor, + StringMapping::RefEqualWorkload_Execute>; -using RefEqualUint8Workload = - RefComparisonWorkload, - DataType::QuantisedAsymm8, - EqualQueueDescriptor, - StringMapping::RefEqualWorkload_Execute>; -using RefGreaterFloat32Workload = +using RefGreaterWorkload = RefComparisonWorkload, - DataType::Float32, - GreaterQueueDescriptor, - StringMapping::RefGreaterWorkload_Execute>; - -using RefGreaterUint8Workload = - RefComparisonWorkload, - DataType::QuantisedAsymm8, - GreaterQueueDescriptor, - StringMapping::RefGreaterWorkload_Execute>; + GreaterQueueDescriptor, + StringMapping::RefGreaterWorkload_Execute>; } // armnn diff --git a/src/backends/reference/workloads/RefElementwiseWorkload.cpp b/src/backends/reference/workloads/RefElementwiseWorkload.cpp index 356d7a0c16..6e6e1d5f21 100644 --- a/src/backends/reference/workloads/RefElementwiseWorkload.cpp +++ b/src/backends/reference/workloads/RefElementwiseWorkload.cpp @@ -14,14 +14,10 @@ namespace armnn { -template -void RefElementwiseWorkload::Execute() const +template +void RefElementwiseWorkload::Execute() const { 
ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, StringMapping::Instance().Get(DebugString)); - const TensorInfo& inputInfo0 = GetTensorInfo(m_Data.m_Inputs[0]); const TensorInfo& inputInfo1 = GetTensorInfo(m_Data.m_Inputs[1]); const TensorInfo& outputInfo = GetTensorInfo(m_Data.m_Outputs[0]); @@ -30,32 +26,46 @@ void RefElementwiseWorkload::E const TensorShape& inShape1 = inputInfo1.GetShape(); const TensorShape& outShape = outputInfo.GetShape(); - switch(DataType) + switch(inputInfo0.GetDataType()) { case armnn::DataType::QuantisedAsymm8: { - std::vector results(outputInfo.GetNumElements()); - ElementwiseFunction(inShape0, - inShape1, - outShape, - Dequantize(GetInputTensorDataU8(0, m_Data), inputInfo0).data(), - Dequantize(GetInputTensorDataU8(1, m_Data), inputInfo1).data(), - results.data()); - Quantize(GetOutputTensorDataU8(0, m_Data), results.data(), outputInfo); + QASymm8Decoder decodeIterator0(GetInputTensorDataU8(0, m_Data), + inputInfo0.GetQuantizationScale(), + inputInfo0.GetQuantizationOffset()); + + QASymm8Decoder decodeIterator1(GetInputTensorDataU8(1, m_Data), + inputInfo1.GetQuantizationScale(), + inputInfo1.GetQuantizationOffset()); + + QASymm8Encoder encodeIterator0(GetOutputTensorDataU8(0, m_Data), + outputInfo.GetQuantizationScale(), + outputInfo.GetQuantizationOffset()); + + ElementwiseFunction(inShape0, + inShape1, + outShape, + decodeIterator0, + decodeIterator1, + encodeIterator0); break; } case armnn::DataType::Float32: { - ElementwiseFunction(inShape0, - inShape1, - outShape, - GetInputTensorDataFloat(0, m_Data), - GetInputTensorDataFloat(1, m_Data), - GetOutputTensorDataFloat(0, m_Data)); + FloatDecoder decodeIterator0(GetInputTensorDataFloat(0, m_Data)); + FloatDecoder decodeIterator1(GetInputTensorDataFloat(1, m_Data)); + FloatEncoder encodeIterator0(GetOutputTensorDataFloat(0, m_Data)); + + ElementwiseFunction(inShape0, + inShape1, + outShape, + decodeIterator0, + decodeIterator1, + encodeIterator0); break; } default: - BOOST_ASSERT_MSG(false, "Unknown Data Type!"); + BOOST_ASSERT_MSG(false, "RefElementwiseWorkload: Not supported Data Type!"); break; } } @@ -63,62 +73,25 @@ void RefElementwiseWorkload::E } template class armnn::RefElementwiseWorkload, - armnn::DataType::Float32, - armnn::AdditionQueueDescriptor, - armnn::StringMapping::RefAdditionWorkload_Execute>; - -template class armnn::RefElementwiseWorkload, - armnn::DataType::QuantisedAsymm8, - armnn::AdditionQueueDescriptor, - armnn::StringMapping::RefAdditionWorkload_Execute>; + armnn::AdditionQueueDescriptor, + armnn::StringMapping::RefAdditionWorkload_Execute>; template class armnn::RefElementwiseWorkload, - armnn::DataType::Float32, - armnn::SubtractionQueueDescriptor, - armnn::StringMapping::RefSubtractionWorkload_Execute>; - -template class armnn::RefElementwiseWorkload, - armnn::DataType::QuantisedAsymm8, - armnn::SubtractionQueueDescriptor, - armnn::StringMapping::RefSubtractionWorkload_Execute>; - -template class armnn::RefElementwiseWorkload, - armnn::DataType::Float32, - armnn::MultiplicationQueueDescriptor, - armnn::StringMapping::RefMultiplicationWorkload_Execute>; + armnn::SubtractionQueueDescriptor, + armnn::StringMapping::RefSubtractionWorkload_Execute>; template class armnn::RefElementwiseWorkload, - armnn::DataType::QuantisedAsymm8, - armnn::MultiplicationQueueDescriptor, - armnn::StringMapping::RefMultiplicationWorkload_Execute>; + armnn::MultiplicationQueueDescriptor, + armnn::StringMapping::RefMultiplicationWorkload_Execute>; template class armnn::RefElementwiseWorkload, - 
armnn::DataType::Float32, - armnn::DivisionQueueDescriptor, - armnn::StringMapping::RefDivisionWorkload_Execute>; - -template class armnn::RefElementwiseWorkload, - armnn::DataType::QuantisedAsymm8, - armnn::DivisionQueueDescriptor, - armnn::StringMapping::RefDivisionWorkload_Execute>; - -template class armnn::RefElementwiseWorkload, - armnn::DataType::Float32, - armnn::MaximumQueueDescriptor, - armnn::StringMapping::RefMaximumWorkload_Execute>; + armnn::DivisionQueueDescriptor, + armnn::StringMapping::RefDivisionWorkload_Execute>; template class armnn::RefElementwiseWorkload, - armnn::DataType::QuantisedAsymm8, - armnn::MaximumQueueDescriptor, - armnn::StringMapping::RefMaximumWorkload_Execute>; - - -template class armnn::RefElementwiseWorkload, - armnn::DataType::Float32, - armnn::MinimumQueueDescriptor, - armnn::StringMapping::RefMinimumWorkload_Execute>; + armnn::MaximumQueueDescriptor, + armnn::StringMapping::RefMaximumWorkload_Execute>; template class armnn::RefElementwiseWorkload, - armnn::DataType::QuantisedAsymm8, - armnn::MinimumQueueDescriptor, - armnn::StringMapping::RefMinimumWorkload_Execute>; + armnn::MinimumQueueDescriptor, + armnn::StringMapping::RefMinimumWorkload_Execute>; \ No newline at end of file diff --git a/src/backends/reference/workloads/RefElementwiseWorkload.hpp b/src/backends/reference/workloads/RefElementwiseWorkload.hpp index 371904977a..81af19627e 100644 --- a/src/backends/reference/workloads/RefElementwiseWorkload.hpp +++ b/src/backends/reference/workloads/RefElementwiseWorkload.hpp @@ -15,90 +15,43 @@ namespace armnn { -template -class RefElementwiseWorkload - : public TypedWorkload +template +class RefElementwiseWorkload : public BaseWorkload { public: - - using TypedWorkload::m_Data; - using TypedWorkload::TypedWorkload; + using BaseWorkload::m_Data; + using BaseWorkload::BaseWorkload; void Execute() const override; }; -using RefAdditionFloat32Workload = - RefElementwiseWorkload, - DataType::Float32, - AdditionQueueDescriptor, - StringMapping::RefAdditionWorkload_Execute>; - -using RefAdditionUint8Workload = +using RefAdditionWorkload = RefElementwiseWorkload, - DataType::QuantisedAsymm8, AdditionQueueDescriptor, StringMapping::RefAdditionWorkload_Execute>; -using RefSubtractionFloat32Workload = - RefElementwiseWorkload, - DataType::Float32, - SubtractionQueueDescriptor, - StringMapping::RefSubtractionWorkload_Execute>; - -using RefSubtractionUint8Workload = +using RefSubtractionWorkload = RefElementwiseWorkload, - DataType::QuantisedAsymm8, SubtractionQueueDescriptor, StringMapping::RefSubtractionWorkload_Execute>; -using RefMultiplicationFloat32Workload = - RefElementwiseWorkload, - DataType::Float32, - MultiplicationQueueDescriptor, - StringMapping::RefMultiplicationWorkload_Execute>; - -using RefMultiplicationUint8Workload = +using RefMultiplicationWorkload = RefElementwiseWorkload, - DataType::QuantisedAsymm8, MultiplicationQueueDescriptor, StringMapping::RefMultiplicationWorkload_Execute>; -using RefDivisionFloat32Workload = - RefElementwiseWorkload, - DataType::Float32, - DivisionQueueDescriptor, - StringMapping::RefDivisionWorkload_Execute>; - -using RefDivisionUint8Workload = +using RefDivisionWorkload = RefElementwiseWorkload, - DataType::QuantisedAsymm8, DivisionQueueDescriptor, StringMapping::RefDivisionWorkload_Execute>; -using RefMaximumFloat32Workload = - RefElementwiseWorkload, - DataType::Float32, - MaximumQueueDescriptor, - StringMapping::RefMaximumWorkload_Execute>; - -using RefMaximumUint8Workload = +using RefMaximumWorkload = 
     RefElementwiseWorkload<armnn::maximum<float>,
-                           DataType::QuantisedAsymm8,
                            MaximumQueueDescriptor,
                            StringMapping::RefMaximumWorkload_Execute>;
 
-using RefMinimumFloat32Workload =
-    RefElementwiseWorkload<armnn::minimum<float>,
-                           DataType::Float32,
-                           MinimumQueueDescriptor,
-                           StringMapping::RefMinimumWorkload_Execute>;
-
-using RefMinimumUint8Workload =
-    RefElementwiseWorkload<armnn::minimum<float>,
-                           DataType::QuantisedAsymm8,
+using RefMinimumWorkload =
+    RefElementwiseWorkload<armnn::minimum<float>,
                            MinimumQueueDescriptor,
                            StringMapping::RefMinimumWorkload_Execute>;
 
 } // armnn
-- 
cgit v1.2.1
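
To make the shape of this refactor concrete, below is a minimal, self-contained sketch of the decode/encode pattern that BaseIterator.hpp introduces. It is an illustration only, not the ArmNN code: Dequantize/Quantize here are simplified stand-ins for the real armnn helpers, the Next() method replaces the patch's operator++/operator+=/operator-= iterator interface, and broadcasting (BroadcastLoop::Unroll) is omitted. What it demonstrates is why RefElementwiseWorkload and RefComparisonWorkload no longer need separate Float32/Uint8 template instantiations: the element data type is resolved at runtime by constructing the matching Decoder/Encoder pair, and one loop drives any functor through Get()/Set().

#include <cstdint>
#include <functional>
#include <iostream>
#include <vector>

// Simplified stand-ins for the armnn Dequantize/Quantize helpers
// (assumption: the real helpers also apply rounding/clamping policies).
float Dequantize(uint8_t value, float scale, int32_t offset)
{
    return scale * static_cast<float>(static_cast<int32_t>(value) - offset);
}

uint8_t Quantize(float value, float scale, int32_t offset)
{
    return static_cast<uint8_t>(value / scale + static_cast<float>(offset) + 0.5f);
}

// Abstract read/write iterators: the workload is no longer templated on the
// tensor DataType, it just walks Decoder/Encoder references.
struct Decoder { virtual ~Decoder() = default; virtual float Get() const = 0; virtual void Next() = 0; };
struct Encoder { virtual ~Encoder() = default; virtual void Set(float v) = 0; virtual void Next() = 0; };

class QASymm8Decoder : public Decoder
{
public:
    QASymm8Decoder(const uint8_t* data, float scale, int32_t offset)
        : m_It(data), m_Scale(scale), m_Offset(offset) {}
    float Get() const override { return Dequantize(*m_It, m_Scale, m_Offset); }
    void Next() override { ++m_It; }
private:
    const uint8_t* m_It;
    float m_Scale;
    int32_t m_Offset;
};

class QASymm8Encoder : public Encoder
{
public:
    QASymm8Encoder(uint8_t* data, float scale, int32_t offset)
        : m_It(data), m_Scale(scale), m_Offset(offset) {}
    void Set(float v) override { *m_It = Quantize(v, m_Scale, m_Offset); }
    void Next() override { ++m_It; }
private:
    uint8_t* m_It;
    float m_Scale;
    int32_t m_Offset;
};

// One elementwise loop for every data type: decode, apply the functor, encode.
// (No broadcasting here; the real BroadcastLoop::Unroll handles mismatched shapes.)
template <typename Functor>
void ElementwiseLoop(Decoder& in0, Decoder& in1, Encoder& out, unsigned int numElements)
{
    Functor f;
    for (unsigned int i = 0; i < numElements; ++i)
    {
        out.Set(f(in0.Get(), in1.Get()));
        in0.Next();
        in1.Next();
        out.Next();
    }
}

int main()
{
    // Two QAsymm8 "tensors" with scale 0.5 and offset 0.
    std::vector<uint8_t> a{1, 2, 3, 4};
    std::vector<uint8_t> b{4, 3, 2, 1};
    std::vector<uint8_t> result(4);

    QASymm8Decoder decode0(a.data(), 0.5f, 0);
    QASymm8Decoder decode1(b.data(), 0.5f, 0);
    QASymm8Encoder encode0(result.data(), 0.5f, 0);

    ElementwiseLoop<std::plus<float>>(decode0, decode1, encode0, 4);

    for (uint8_t v : result) { std::cout << static_cast<int>(v) << ' '; }  // prints: 5 5 5 5
    std::cout << '\n';
    return 0;
}

Built with any C++11 compiler this prints "5 5 5 5": each pair of QAsymm8 inputs is dequantized with scale 0.5, added as float, and re-quantized on the way out — the same flow the QuantisedAsymm8 branch of RefElementwiseWorkload::Execute now follows with QASymm8Decoder/QASymm8Encoder, while the Float32 branch simply constructs FloatDecoder/FloatEncoder over the same loop.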