diff options
Diffstat (limited to 'src/backends')
13 files changed, 438 insertions, 285 deletions
diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp index b850a65acf..1360ac5d0c 100644 --- a/src/backends/backendsCommon/WorkloadData.cpp +++ b/src/backends/backendsCommon/WorkloadData.cpp @@ -491,13 +491,29 @@ void AdditionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const ValidateNumInputs(workloadInfo, "AdditionQueueDescriptor", 2); ValidateNumOutputs(workloadInfo, "AdditionQueueDescriptor", 1); + std::vector<DataType> supportedTypes = { + DataType::Float32, + DataType::QuantisedAsymm8 + }; + + ValidateDataTypes(workloadInfo.m_InputTensorInfos[0], + supportedTypes, + "AdditionQueueDescriptor"); + + ValidateDataTypes(workloadInfo.m_InputTensorInfos[1], + supportedTypes, + "AdditionQueueDescriptor"); + + ValidateDataTypes(workloadInfo.m_OutputTensorInfos[0], + supportedTypes, + "AdditionQueueDescriptor"); + ValidateBroadcastTensorShapesMatch(workloadInfo.m_InputTensorInfos[0], workloadInfo.m_InputTensorInfos[1], workloadInfo.m_OutputTensorInfos[0], "AdditionQueueDescriptor", "first input", "second input"); - } //--------------------------------------------------------------- @@ -506,6 +522,23 @@ void MultiplicationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) c ValidateNumInputs(workloadInfo, "MultiplicationQueueDescriptor", 2); ValidateNumOutputs(workloadInfo, "MultiplicationQueueDescriptor", 1); + std::vector<DataType> supportedTypes = { + DataType::Float32, + DataType::QuantisedAsymm8 + }; + + ValidateDataTypes(workloadInfo.m_InputTensorInfos[0], + supportedTypes, + "MultiplicationQueueDescriptor"); + + ValidateDataTypes(workloadInfo.m_InputTensorInfos[1], + supportedTypes, + "MultiplicationQueueDescriptor"); + + ValidateDataTypes(workloadInfo.m_OutputTensorInfos[0], + supportedTypes, + "MultiplicationQueueDescriptor"); + ValidateBroadcastTensorShapesMatch(workloadInfo.m_InputTensorInfos[0], workloadInfo.m_InputTensorInfos[1], workloadInfo.m_OutputTensorInfos[0], @@ -857,6 +890,23 @@ void DivisionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const ValidateNumInputs(workloadInfo, "DivisionQueueDescriptor", 2); ValidateNumOutputs(workloadInfo, "DivisionQueueDescriptor", 1); + std::vector<DataType> supportedTypes = { + DataType::Float32, + DataType::QuantisedAsymm8 + }; + + ValidateDataTypes(workloadInfo.m_InputTensorInfos[0], + supportedTypes, + "DivisionQueueDescriptor"); + + ValidateDataTypes(workloadInfo.m_InputTensorInfos[1], + supportedTypes, + "DivisionQueueDescriptor"); + + ValidateDataTypes(workloadInfo.m_OutputTensorInfos[0], + supportedTypes, + "DivisionQueueDescriptor"); + ValidateBroadcastTensorShapesMatch(workloadInfo.m_InputTensorInfos[0], workloadInfo.m_InputTensorInfos[1], workloadInfo.m_OutputTensorInfos[0], @@ -870,6 +920,23 @@ void SubtractionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) cons ValidateNumInputs(workloadInfo, "SubtractionQueueDescriptor", 2); ValidateNumOutputs(workloadInfo, "SubtractionQueueDescriptor", 1); + std::vector<DataType> supportedTypes = { + DataType::Float32, + DataType::QuantisedAsymm8 + }; + + ValidateDataTypes(workloadInfo.m_InputTensorInfos[0], + supportedTypes, + "SubtractionQueueDescriptor"); + + ValidateDataTypes(workloadInfo.m_InputTensorInfos[1], + supportedTypes, + "SubtractionQueueDescriptor"); + + ValidateDataTypes(workloadInfo.m_OutputTensorInfos[0], + supportedTypes, + "SubtractionQueueDescriptor"); + ValidateBroadcastTensorShapesMatch(workloadInfo.m_InputTensorInfos[0], workloadInfo.m_InputTensorInfos[1], workloadInfo.m_OutputTensorInfos[0], @@ -883,6 +950,23 @@ void MaximumQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const ValidateNumInputs(workloadInfo, "MaximumQueueDescriptor", 2); ValidateNumOutputs(workloadInfo, "MaximumQueueDescriptor", 1); + std::vector<DataType> supportedTypes = { + DataType::Float32, + DataType::QuantisedAsymm8 + }; + + ValidateDataTypes(workloadInfo.m_InputTensorInfos[0], + supportedTypes, + "MaximumQueueDescriptor"); + + ValidateDataTypes(workloadInfo.m_InputTensorInfos[1], + supportedTypes, + "MaximumQueueDescriptor"); + + ValidateDataTypes(workloadInfo.m_OutputTensorInfos[0], + supportedTypes, + "MaximumQueueDescriptor"); + ValidateBroadcastTensorShapesMatch(workloadInfo.m_InputTensorInfos[0], workloadInfo.m_InputTensorInfos[1], workloadInfo.m_OutputTensorInfos[0], @@ -1008,6 +1092,23 @@ void MinimumQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const ValidateNumInputs(workloadInfo, "MinimumQueueDescriptor", 2); ValidateNumOutputs(workloadInfo, "MinimumQueueDescriptor", 1); + std::vector<DataType> supportedTypes = { + DataType::Float32, + DataType::QuantisedAsymm8 + }; + + ValidateDataTypes(workloadInfo.m_InputTensorInfos[0], + supportedTypes, + "MinimumQueueDescriptor"); + + ValidateDataTypes(workloadInfo.m_InputTensorInfos[1], + supportedTypes, + "MinimumQueueDescriptor"); + + ValidateDataTypes(workloadInfo.m_OutputTensorInfos[0], + supportedTypes, + "MinimumQueueDescriptor"); + ValidateBroadcastTensorShapesMatch(workloadInfo.m_InputTensorInfos[0], workloadInfo.m_InputTensorInfos[1], workloadInfo.m_OutputTensorInfos[0], diff --git a/src/backends/backendsCommon/test/WorkloadDataValidation.cpp b/src/backends/backendsCommon/test/WorkloadDataValidation.cpp index 3664d56c28..d37cc74c66 100644 --- a/src/backends/backendsCommon/test/WorkloadDataValidation.cpp +++ b/src/backends/backendsCommon/test/WorkloadDataValidation.cpp @@ -312,17 +312,17 @@ BOOST_AUTO_TEST_CASE(AdditionQueueDescriptor_Validate_InputNumbers) AddOutputToWorkload(invalidData, invalidInfo, outputTensorInfo, nullptr); // Too few inputs. - BOOST_CHECK_THROW(RefAdditionFloat32Workload(invalidData, invalidInfo), armnn::InvalidArgumentException); + BOOST_CHECK_THROW(RefAdditionWorkload(invalidData, invalidInfo), armnn::InvalidArgumentException); AddInputToWorkload(invalidData, invalidInfo, input2TensorInfo, nullptr); // Correct. - BOOST_CHECK_NO_THROW(RefAdditionFloat32Workload(invalidData, invalidInfo)); + BOOST_CHECK_NO_THROW(RefAdditionWorkload(invalidData, invalidInfo)); AddInputToWorkload(invalidData, invalidInfo, input3TensorInfo, nullptr); // Too many inputs. - BOOST_CHECK_THROW(RefAdditionFloat32Workload(invalidData, invalidInfo), armnn::InvalidArgumentException); + BOOST_CHECK_THROW(RefAdditionWorkload(invalidData, invalidInfo), armnn::InvalidArgumentException); } BOOST_AUTO_TEST_CASE(AdditionQueueDescriptor_Validate_InputShapes) @@ -347,7 +347,7 @@ BOOST_AUTO_TEST_CASE(AdditionQueueDescriptor_Validate_InputShapes) AddInputToWorkload(invalidData, invalidInfo, input2TensorInfo, nullptr); AddOutputToWorkload(invalidData, invalidInfo, outputTensorInfo, nullptr); - BOOST_CHECK_THROW(RefAdditionFloat32Workload(invalidData, invalidInfo), armnn::InvalidArgumentException); + BOOST_CHECK_THROW(RefAdditionWorkload(invalidData, invalidInfo), armnn::InvalidArgumentException); } // Output size not compatible with input sizes. @@ -364,7 +364,7 @@ BOOST_AUTO_TEST_CASE(AdditionQueueDescriptor_Validate_InputShapes) AddOutputToWorkload(invalidData, invalidInfo, outputTensorInfo, nullptr); // Output differs. - BOOST_CHECK_THROW(RefAdditionFloat32Workload(invalidData, invalidInfo), armnn::InvalidArgumentException); + BOOST_CHECK_THROW(RefAdditionWorkload(invalidData, invalidInfo), armnn::InvalidArgumentException); } } @@ -399,7 +399,7 @@ BOOST_AUTO_TEST_CASE(MultiplicationQueueDescriptor_Validate_InputTensorDimension AddInputToWorkload(invalidData, invalidInfo, input0TensorInfo, nullptr); AddInputToWorkload(invalidData, invalidInfo, input1TensorInfo, nullptr); - BOOST_CHECK_THROW(RefMultiplicationFloat32Workload(invalidData, invalidInfo), armnn::InvalidArgumentException); + BOOST_CHECK_THROW(RefMultiplicationWorkload(invalidData, invalidInfo), armnn::InvalidArgumentException); } // Checks dimension consistency for input and output tensors. @@ -424,7 +424,7 @@ BOOST_AUTO_TEST_CASE(MultiplicationQueueDescriptor_Validate_InputTensorDimension AddInputToWorkload(invalidData, invalidInfo, input0TensorInfo, nullptr); AddInputToWorkload(invalidData, invalidInfo, input1TensorInfo, nullptr); - BOOST_CHECK_THROW(RefMultiplicationFloat32Workload(invalidData, invalidInfo), armnn::InvalidArgumentException); + BOOST_CHECK_THROW(RefMultiplicationWorkload(invalidData, invalidInfo), armnn::InvalidArgumentException); } } diff --git a/src/backends/reference/RefWorkloadFactory.cpp b/src/backends/reference/RefWorkloadFactory.cpp index 619c14e007..8ea923d599 100644 --- a/src/backends/reference/RefWorkloadFactory.cpp +++ b/src/backends/reference/RefWorkloadFactory.cpp @@ -174,13 +174,13 @@ std::unique_ptr<armnn::IWorkload> RefWorkloadFactory::CreateNormalization( std::unique_ptr<armnn::IWorkload> RefWorkloadFactory::CreateAddition(const AdditionQueueDescriptor& descriptor, const WorkloadInfo& info) const { - return MakeWorkload<RefAdditionFloat32Workload, RefAdditionUint8Workload>(descriptor, info); + return std::make_unique<RefAdditionWorkload>(descriptor, info); } std::unique_ptr<armnn::IWorkload> RefWorkloadFactory::CreateMultiplication( const MultiplicationQueueDescriptor& descriptor, const WorkloadInfo& info) const { - return MakeWorkload<RefMultiplicationFloat32Workload, RefMultiplicationUint8Workload>(descriptor, info); + return std::make_unique<RefMultiplicationWorkload>(descriptor, info); } std::unique_ptr<armnn::IWorkload> RefWorkloadFactory::CreateBatchNormalization( @@ -266,19 +266,19 @@ std::unique_ptr<IWorkload> RefWorkloadFactory::CreateConvertFp32ToFp16( std::unique_ptr<armnn::IWorkload> RefWorkloadFactory::CreateDivision( const DivisionQueueDescriptor& descriptor, const WorkloadInfo& info) const { - return MakeWorkload<RefDivisionFloat32Workload, RefDivisionUint8Workload>(descriptor, info); + return std::make_unique<RefDivisionWorkload>(descriptor, info); } std::unique_ptr<armnn::IWorkload> RefWorkloadFactory::CreateSubtraction( const SubtractionQueueDescriptor& descriptor, const WorkloadInfo& info) const { - return MakeWorkload<RefSubtractionFloat32Workload, RefSubtractionUint8Workload>(descriptor, info); + return std::make_unique<RefSubtractionWorkload>(descriptor, info); } std::unique_ptr<armnn::IWorkload> RefWorkloadFactory::CreateMaximum( const MaximumQueueDescriptor& descriptor, const WorkloadInfo& info) const { - return MakeWorkload<RefMaximumFloat32Workload, RefMaximumUint8Workload>(descriptor, info); + return std::make_unique<RefMaximumWorkload>(descriptor, info); } std::unique_ptr<armnn::IWorkload> RefWorkloadFactory::CreateMean( @@ -290,7 +290,7 @@ std::unique_ptr<armnn::IWorkload> RefWorkloadFactory::CreateMean( std::unique_ptr<armnn::IWorkload> RefWorkloadFactory::CreateMinimum( const MinimumQueueDescriptor& descriptor, const WorkloadInfo& info) const { - return MakeWorkload<RefMinimumFloat32Workload, RefMinimumUint8Workload>(descriptor, info); + return std::make_unique<RefMinimumWorkload>(descriptor, info); } std::unique_ptr<IWorkload> RefWorkloadFactory::CreatePad(const PadQueueDescriptor& descriptor, @@ -302,7 +302,7 @@ std::unique_ptr<IWorkload> RefWorkloadFactory::CreatePad(const PadQueueDescripto std::unique_ptr<IWorkload> RefWorkloadFactory::CreateEqual(const EqualQueueDescriptor& descriptor, const WorkloadInfo& info) const { - return MakeWorkload<RefEqualFloat32Workload, RefEqualUint8Workload>(descriptor, info); + return std::make_unique<RefEqualWorkload>(descriptor, info); } std::unique_ptr<IWorkload> RefWorkloadFactory::CreateBatchToSpaceNd(const BatchToSpaceNdQueueDescriptor& descriptor, @@ -320,7 +320,7 @@ std::unique_ptr<IWorkload> RefWorkloadFactory::CreateStridedSlice(const StridedS std::unique_ptr<IWorkload> RefWorkloadFactory::CreateGreater(const GreaterQueueDescriptor& descriptor, const WorkloadInfo& info) const { - return MakeWorkload<RefGreaterFloat32Workload, RefGreaterUint8Workload>(descriptor, info); + return std::make_unique<RefGreaterWorkload>(descriptor, info); } std::unique_ptr<IWorkload> RefWorkloadFactory::CreateDebug(const DebugQueueDescriptor& descriptor, diff --git a/src/backends/reference/test/RefCreateWorkloadTests.cpp b/src/backends/reference/test/RefCreateWorkloadTests.cpp index 8621122925..09b0246895 100644 --- a/src/backends/reference/test/RefCreateWorkloadTests.cpp +++ b/src/backends/reference/test/RefCreateWorkloadTests.cpp @@ -82,7 +82,7 @@ static void RefCreateElementwiseWorkloadTest() BOOST_AUTO_TEST_CASE(CreateAdditionFloatWorkload) { - RefCreateElementwiseWorkloadTest<RefAdditionFloat32Workload, + RefCreateElementwiseWorkloadTest<RefAdditionWorkload, AdditionQueueDescriptor, AdditionLayer, armnn::DataType::Float32>(); @@ -90,7 +90,7 @@ BOOST_AUTO_TEST_CASE(CreateAdditionFloatWorkload) BOOST_AUTO_TEST_CASE(CreateAdditionUint8Workload) { - RefCreateElementwiseWorkloadTest<RefAdditionUint8Workload, + RefCreateElementwiseWorkloadTest<RefAdditionWorkload, AdditionQueueDescriptor, AdditionLayer, armnn::DataType::QuantisedAsymm8>(); @@ -98,7 +98,7 @@ BOOST_AUTO_TEST_CASE(CreateAdditionUint8Workload) BOOST_AUTO_TEST_CASE(CreateSubtractionFloatWorkload) { - RefCreateElementwiseWorkloadTest<RefSubtractionFloat32Workload, + RefCreateElementwiseWorkloadTest<RefSubtractionWorkload, SubtractionQueueDescriptor, SubtractionLayer, armnn::DataType::Float32>(); @@ -106,7 +106,7 @@ BOOST_AUTO_TEST_CASE(CreateSubtractionFloatWorkload) BOOST_AUTO_TEST_CASE(CreateSubtractionUint8Workload) { - RefCreateElementwiseWorkloadTest<RefSubtractionUint8Workload, + RefCreateElementwiseWorkloadTest<RefSubtractionWorkload, SubtractionQueueDescriptor, SubtractionLayer, armnn::DataType::QuantisedAsymm8>(); @@ -114,7 +114,7 @@ BOOST_AUTO_TEST_CASE(CreateSubtractionUint8Workload) BOOST_AUTO_TEST_CASE(CreateMultiplicationFloatWorkload) { - RefCreateElementwiseWorkloadTest<RefMultiplicationFloat32Workload, + RefCreateElementwiseWorkloadTest<RefMultiplicationWorkload, MultiplicationQueueDescriptor, MultiplicationLayer, armnn::DataType::Float32>(); @@ -122,7 +122,7 @@ BOOST_AUTO_TEST_CASE(CreateMultiplicationFloatWorkload) BOOST_AUTO_TEST_CASE(CreateMultiplicationUint8Workload) { - RefCreateElementwiseWorkloadTest<RefMultiplicationUint8Workload, + RefCreateElementwiseWorkloadTest<RefMultiplicationWorkload, MultiplicationQueueDescriptor, MultiplicationLayer, armnn::DataType::QuantisedAsymm8>(); @@ -130,7 +130,7 @@ BOOST_AUTO_TEST_CASE(CreateMultiplicationUint8Workload) BOOST_AUTO_TEST_CASE(CreateDivisionFloatWorkload) { - RefCreateElementwiseWorkloadTest<RefDivisionFloat32Workload, + RefCreateElementwiseWorkloadTest<RefDivisionWorkload, DivisionQueueDescriptor, DivisionLayer, armnn::DataType::Float32>(); @@ -138,7 +138,7 @@ BOOST_AUTO_TEST_CASE(CreateDivisionFloatWorkload) BOOST_AUTO_TEST_CASE(CreateDivisionUint8Workload) { - RefCreateElementwiseWorkloadTest<RefDivisionUint8Workload, + RefCreateElementwiseWorkloadTest<RefDivisionWorkload, DivisionQueueDescriptor, DivisionLayer, armnn::DataType::QuantisedAsymm8>(); diff --git a/src/backends/reference/workloads/BaseIterator.hpp b/src/backends/reference/workloads/BaseIterator.hpp new file mode 100644 index 0000000000..cfa8ce7e91 --- /dev/null +++ b/src/backends/reference/workloads/BaseIterator.hpp @@ -0,0 +1,155 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#pragma once + +#include <armnn/ArmNN.hpp> +#include <TypeUtils.hpp> + +namespace armnn +{ + +class BaseIterator +{ +public: + BaseIterator() {} + + virtual ~BaseIterator() {} + + virtual BaseIterator& operator++() = 0; + + virtual BaseIterator& operator+=(const unsigned int increment) = 0; + + virtual BaseIterator& operator-=(const unsigned int increment) = 0; +}; + +class Decoder : public BaseIterator +{ +public: + Decoder() : BaseIterator() {} + + virtual ~Decoder() {} + + virtual float Get() const = 0; +}; + +class Encoder : public BaseIterator +{ +public: + Encoder() : BaseIterator() {} + + virtual ~Encoder() {} + + virtual void Set(const float& right) = 0; +}; + +class ComparisonEncoder : public BaseIterator +{ +public: + ComparisonEncoder() : BaseIterator() {} + + virtual ~ComparisonEncoder() {} + + virtual void Set(bool right) = 0; +}; + +template<typename T, typename Base> +class TypedIterator : public Base +{ +public: + TypedIterator(T* data) + : m_Iterator(data) + {} + + TypedIterator& operator++() override + { + ++m_Iterator; + return *this; + } + + TypedIterator& operator+=(const unsigned int increment) override + { + m_Iterator += increment; + return *this; + } + + TypedIterator& operator-=(const unsigned int increment) override + { + m_Iterator -= increment; + return *this; + } + + T* m_Iterator; +}; + +class QASymm8Decoder : public TypedIterator<const uint8_t, Decoder> +{ +public: + QASymm8Decoder(const uint8_t* data, const float scale, const int32_t offset) + : TypedIterator(data), m_Scale(scale), m_Offset(offset) {} + + float Get() const override + { + return armnn::Dequantize(*m_Iterator, m_Scale, m_Offset); + } + +private: + const float m_Scale; + const int32_t m_Offset; +}; + +class FloatDecoder : public TypedIterator<const float, Decoder> +{ +public: + FloatDecoder(const float* data) + : TypedIterator(data) {} + + float Get() const override + { + return *m_Iterator; + } +}; + +class FloatEncoder : public TypedIterator<float, Encoder> +{ +public: + FloatEncoder(float* data) + : TypedIterator(data) {} + + void Set(const float& right) override + { + *m_Iterator = right; + } +}; + +class QASymm8Encoder : public TypedIterator<uint8_t, Encoder> +{ +public: + QASymm8Encoder(uint8_t* data, const float scale, const int32_t offset) + : TypedIterator(data), m_Scale(scale), m_Offset(offset) {} + + void Set(const float& right) override + { + *m_Iterator = armnn::Quantize<uint8_t>(right, m_Scale, m_Offset); + } + +private: + const float m_Scale; + const int32_t m_Offset; +}; + +class BooleanEncoder : public TypedIterator<uint8_t, ComparisonEncoder> +{ +public: + BooleanEncoder(uint8_t* data) + : TypedIterator(data) {} + + void Set(bool right) override + { + *m_Iterator = right; + } +}; + +} //namespace armnn
\ No newline at end of file diff --git a/src/backends/reference/workloads/Broadcast.hpp b/src/backends/reference/workloads/Broadcast.hpp index e92ed0598d..5bf6be8939 100644 --- a/src/backends/reference/workloads/Broadcast.hpp +++ b/src/backends/reference/workloads/Broadcast.hpp @@ -3,6 +3,7 @@ // SPDX-License-Identifier: MIT // +#include "BaseIterator.hpp" #include <armnn/Tensor.hpp> #include <functional> @@ -19,19 +20,23 @@ struct BroadcastLoop return static_cast<unsigned int>(m_DimData.size()); } - template <typename T0, typename T1, typename U, typename Func> + template <typename Func, typename DecoderOp, typename EncoderOp> void Unroll(Func operationFunc, unsigned int dimension, - const T0* inData0, - const T1* inData1, - U* outData) + DecoderOp& inData0, + DecoderOp& inData1, + EncoderOp& outData) { if (dimension >= GetNumDimensions()) { - *outData = operationFunc(*inData0, *inData1); + outData.Set(operationFunc(inData0.Get(), inData1.Get())); return; } + unsigned int inData0Movement = 0; + unsigned int inData1Movement = 0; + unsigned int outDataMovement = 0; + for (unsigned int i = 0; i < m_DimData[dimension].m_DimSize; i++) { Unroll(operationFunc, dimension + 1, inData0, inData1, outData); @@ -39,7 +44,16 @@ struct BroadcastLoop inData0 += m_DimData[dimension].m_Stride1; inData1 += m_DimData[dimension].m_Stride2; outData += m_DimData[dimension].m_StrideOut; + + inData0Movement += m_DimData[dimension].m_Stride1; + inData1Movement += m_DimData[dimension].m_Stride2; + outDataMovement += m_DimData[dimension].m_StrideOut; } + + // move iterator back to the start + inData0 -= inData0Movement; + inData1 -= inData1Movement; + outData -= outDataMovement; } private: diff --git a/src/backends/reference/workloads/CMakeLists.txt b/src/backends/reference/workloads/CMakeLists.txt index 4f5fbb554e..4ff2466e87 100644 --- a/src/backends/reference/workloads/CMakeLists.txt +++ b/src/backends/reference/workloads/CMakeLists.txt @@ -6,6 +6,7 @@ list(APPEND armnnRefBackendWorkloads_sources Activation.cpp Activation.hpp + BaseIterator.hpp BatchNormImpl.hpp BatchToSpaceNd.cpp BatchToSpaceNd.hpp diff --git a/src/backends/reference/workloads/ElementwiseFunction.cpp b/src/backends/reference/workloads/ElementwiseFunction.cpp index c8c25ef9e9..934a86217a 100644 --- a/src/backends/reference/workloads/ElementwiseFunction.cpp +++ b/src/backends/reference/workloads/ElementwiseFunction.cpp @@ -13,26 +13,26 @@ namespace armnn { -template <typename Functor, typename dataTypeInput, typename dataTypeOutput> -ElementwiseFunction<Functor, dataTypeInput, dataTypeOutput>::ElementwiseFunction(const TensorShape& inShape0, - const TensorShape& inShape1, - const TensorShape& outShape, - const dataTypeInput* inData0, - const dataTypeInput* inData1, - dataTypeOutput* outData) +template <typename Functor, typename DecoderOp, typename EncoderOp> +ElementwiseFunction<Functor, DecoderOp, EncoderOp>::ElementwiseFunction(const TensorShape& inShape0, + const TensorShape& inShape1, + const TensorShape& outShape, + DecoderOp& inData0, + DecoderOp& inData1, + EncoderOp& outData) { BroadcastLoop(inShape0, inShape1, outShape).Unroll(Functor(), 0, inData0, inData1, outData); } } //namespace armnn -template struct armnn::ElementwiseFunction<std::plus<float>, float, float>; -template struct armnn::ElementwiseFunction<std::minus<float>, float, float>; -template struct armnn::ElementwiseFunction<std::multiplies<float>, float, float>; -template struct armnn::ElementwiseFunction<std::divides<float>, float, float>; -template struct armnn::ElementwiseFunction<armnn::maximum<float>, float, float>; -template struct armnn::ElementwiseFunction<armnn::minimum<float>, float, float>; -template struct armnn::ElementwiseFunction<std::equal_to<float>, float ,uint8_t>; -template struct armnn::ElementwiseFunction<std::equal_to<uint8_t>, uint8_t, uint8_t>; -template struct armnn::ElementwiseFunction<std::greater<float>, float, uint8_t>; -template struct armnn::ElementwiseFunction<std::greater<uint8_t>, uint8_t, uint8_t>; +template struct armnn::ElementwiseFunction<std::plus<float>, armnn::Decoder, armnn::Encoder>; +template struct armnn::ElementwiseFunction<std::minus<float>, armnn::Decoder, armnn::Encoder>; +template struct armnn::ElementwiseFunction<std::multiplies<float>, armnn::Decoder, armnn::Encoder>; +template struct armnn::ElementwiseFunction<std::divides<float>, armnn::Decoder, armnn::Encoder>; +template struct armnn::ElementwiseFunction<armnn::maximum<float>, armnn::Decoder, armnn::Encoder>; +template struct armnn::ElementwiseFunction<armnn::minimum<float>, armnn::Decoder, armnn::Encoder>; + +template struct armnn::ElementwiseFunction<std::equal_to<float>, armnn::Decoder, armnn::ComparisonEncoder>; +template struct armnn::ElementwiseFunction<std::greater<float>, armnn::Decoder, armnn::ComparisonEncoder>; + diff --git a/src/backends/reference/workloads/ElementwiseFunction.hpp b/src/backends/reference/workloads/ElementwiseFunction.hpp index 8099f3279a..9eb003d5f9 100644 --- a/src/backends/reference/workloads/ElementwiseFunction.hpp +++ b/src/backends/reference/workloads/ElementwiseFunction.hpp @@ -5,20 +5,21 @@ #pragma once +#include "BaseIterator.hpp" #include <armnn/Tensor.hpp> namespace armnn { -template <typename Functor, typename dataTypeInput, typename dataTypeOutput> +template <typename Functor, typename DecoderOp, typename EncoderOp> struct ElementwiseFunction { ElementwiseFunction(const TensorShape& inShape0, const TensorShape& inShape1, const TensorShape& outShape, - const dataTypeInput* inData0, - const dataTypeInput* inData1, - dataTypeOutput* outData); + DecoderOp& inData0, + DecoderOp& inData1, + EncoderOp& outData); }; } //namespace armnn diff --git a/src/backends/reference/workloads/RefComparisonWorkload.cpp b/src/backends/reference/workloads/RefComparisonWorkload.cpp index fe517ff51a..bb8bb04ad3 100644 --- a/src/backends/reference/workloads/RefComparisonWorkload.cpp +++ b/src/backends/reference/workloads/RefComparisonWorkload.cpp @@ -11,55 +11,66 @@ namespace armnn { -template<typename ParentDescriptor, typename Functor> -void RefFloat32ComparisonWorkload<ParentDescriptor, Functor>::ExecuteImpl(const char* debugString) const +template <typename Functor, typename ParentDescriptor, typename armnn::StringMapping::Id DebugString> +void RefComparisonWorkload<Functor, ParentDescriptor, DebugString>::Execute() const { - ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, debugString); + ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, StringMapping::Instance().Get(DebugString)); + const TensorInfo& inputInfo0 = GetTensorInfo(m_Data.m_Inputs[0]); + const TensorInfo& inputInfo1 = GetTensorInfo(m_Data.m_Inputs[1]); + const TensorInfo& outputInfo = GetTensorInfo(m_Data.m_Outputs[0]); - auto data = BaseFloat32ComparisonWorkload<ParentDescriptor>::GetData(); - const TensorShape& inShape0 = GetTensorInfo(data.m_Inputs[0]).GetShape(); - const TensorShape& inShape1 = GetTensorInfo(data.m_Inputs[1]).GetShape(); - const TensorShape& outputShape = GetTensorInfo(data.m_Outputs[0]).GetShape(); + const TensorShape& inShape0 = inputInfo0.GetShape(); + const TensorShape& inShape1 = inputInfo1.GetShape(); + const TensorShape& outShape = outputInfo.GetShape(); - const float* inData0 = GetInputTensorDataFloat(0, data); - const float* inData1 = GetInputTensorDataFloat(1, data); - uint8_t* outData = GetOutputTensorData<uint8_t>(0, data); + switch(inputInfo0.GetDataType()) + { + case armnn::DataType::QuantisedAsymm8: + { + QASymm8Decoder decodeIterator0(GetInputTensorDataU8(0, m_Data), + inputInfo0.GetQuantizationScale(), + inputInfo0.GetQuantizationOffset()); - ElementwiseFunction<Functor, float, uint8_t>(inShape0, - inShape1, - outputShape, - inData0, - inData1, - outData); + QASymm8Decoder decodeIterator1(GetInputTensorDataU8(1, m_Data), + inputInfo1.GetQuantizationScale(), + inputInfo1.GetQuantizationOffset()); -} - -template<typename ParentDescriptor, typename Functor> -void RefUint8ComparisonWorkload<ParentDescriptor, Functor>::ExecuteImpl(const char* debugString) const -{ - ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, debugString); - - auto data = BaseUint8ComparisonWorkload<ParentDescriptor>::GetData(); - const TensorShape& inputInfo0 = GetTensorInfo(data.m_Inputs[0]).GetShape(); - const TensorShape& inputInfo1 = GetTensorInfo(data.m_Inputs[1]).GetShape(); - const TensorShape& outputShape = GetTensorInfo(data.m_Outputs[0]).GetShape(); + BooleanEncoder encodeIterator0(GetOutputTensorDataU8(0, m_Data)); - const uint8_t* inData0 = GetInputTensorData<uint8_t>(0, data); - const uint8_t* inData1 = GetInputTensorData<uint8_t>(1, data); - uint8_t* outData = GetOutputTensorData<uint8_t>(0, data); + ElementwiseFunction<Functor, Decoder, ComparisonEncoder>(inShape0, + inShape1, + outShape, + decodeIterator0, + decodeIterator1, + encodeIterator0); + break; + } + case armnn::DataType::Float32: + { + FloatDecoder decodeIterator0(GetInputTensorDataFloat(0, m_Data)); + FloatDecoder decodeIterator1(GetInputTensorDataFloat(1, m_Data)); + BooleanEncoder encodeIterator0(GetOutputTensorDataU8(0, m_Data)); - ElementwiseFunction<Functor, uint8_t, uint8_t>(inputInfo0, - inputInfo1, - outputShape, - inData0, - inData1, - outData); + ElementwiseFunction<Functor, Decoder, ComparisonEncoder>(inShape0, + inShape1, + outShape, + decodeIterator0, + decodeIterator1, + encodeIterator0); + break; + } + default: + BOOST_ASSERT_MSG(false, "RefComparisonWorkload: Not supported Data Type!"); + break; + } } } -template class armnn::RefFloat32ComparisonWorkload<armnn::EqualQueueDescriptor, std::equal_to<float>>; -template class armnn::RefUint8ComparisonWorkload<armnn::EqualQueueDescriptor, std::equal_to<uint8_t>>; +template class armnn::RefComparisonWorkload<std::equal_to<float>, + armnn::EqualQueueDescriptor, + armnn::StringMapping::RefEqualWorkload_Execute>; -template class armnn::RefFloat32ComparisonWorkload<armnn::GreaterQueueDescriptor, std::greater<float>>; -template class armnn::RefUint8ComparisonWorkload<armnn::GreaterQueueDescriptor, std::greater<uint8_t>>; +template class armnn::RefComparisonWorkload<std::greater<float>, + armnn::GreaterQueueDescriptor, + armnn::StringMapping::RefGreaterWorkload_Execute>; diff --git a/src/backends/reference/workloads/RefComparisonWorkload.hpp b/src/backends/reference/workloads/RefComparisonWorkload.hpp index 524d20625a..cfc2dcf2aa 100644 --- a/src/backends/reference/workloads/RefComparisonWorkload.hpp +++ b/src/backends/reference/workloads/RefComparisonWorkload.hpp @@ -13,80 +13,24 @@ namespace armnn { -template <typename Functor, - typename armnn::DataType DataType, - typename ParentDescriptor, - typename armnn::StringMapping::Id DebugString> -class RefComparisonWorkload -{ - // Needs specialization. The default is empty on purpose. -}; - -template <typename ParentDescriptor, typename Functor> -class RefFloat32ComparisonWorkload : public BaseFloat32ComparisonWorkload<ParentDescriptor> -{ -public: - using BaseFloat32ComparisonWorkload<ParentDescriptor>::BaseFloat32ComparisonWorkload; - void ExecuteImpl(const char * debugString) const; -}; - -template <typename Functor, typename ParentDescriptor, typename armnn::StringMapping::Id DebugString> -class RefComparisonWorkload<Functor, armnn::DataType::Float32, ParentDescriptor, DebugString> - : public RefFloat32ComparisonWorkload<ParentDescriptor, Functor> -{ -public: - using RefFloat32ComparisonWorkload<ParentDescriptor, Functor>::RefFloat32ComparisonWorkload; - - virtual void Execute() const override - { - using Parent = RefFloat32ComparisonWorkload<ParentDescriptor, Functor>; - Parent::ExecuteImpl(StringMapping::Instance().Get(DebugString)); - } -}; - -template <typename ParentDescriptor, typename Functor> -class RefUint8ComparisonWorkload : public BaseUint8ComparisonWorkload<ParentDescriptor> -{ -public: - using BaseUint8ComparisonWorkload<ParentDescriptor>::BaseUint8ComparisonWorkload; - void ExecuteImpl(const char * debugString) const; -}; - template <typename Functor, typename ParentDescriptor, typename armnn::StringMapping::Id DebugString> -class RefComparisonWorkload<Functor, armnn::DataType::QuantisedAsymm8, ParentDescriptor, DebugString> - : public RefUint8ComparisonWorkload<ParentDescriptor, Functor> +class RefComparisonWorkload : public BaseWorkload<ParentDescriptor> { public: - using RefUint8ComparisonWorkload<ParentDescriptor, Functor>::RefUint8ComparisonWorkload; + using BaseWorkload<ParentDescriptor>::m_Data; + using BaseWorkload<ParentDescriptor>::BaseWorkload; - virtual void Execute() const override - { - using Parent = RefUint8ComparisonWorkload<ParentDescriptor, Functor>; - Parent::ExecuteImpl(StringMapping::Instance().Get(DebugString)); - } + void Execute() const override; }; -using RefEqualFloat32Workload = +using RefEqualWorkload = RefComparisonWorkload<std::equal_to<float>, - DataType::Float32, - EqualQueueDescriptor, - StringMapping::RefEqualWorkload_Execute>; + EqualQueueDescriptor, + StringMapping::RefEqualWorkload_Execute>; -using RefEqualUint8Workload = - RefComparisonWorkload<std::equal_to<uint8_t>, - DataType::QuantisedAsymm8, - EqualQueueDescriptor, - StringMapping::RefEqualWorkload_Execute>; -using RefGreaterFloat32Workload = +using RefGreaterWorkload = RefComparisonWorkload<std::greater<float>, - DataType::Float32, - GreaterQueueDescriptor, - StringMapping::RefGreaterWorkload_Execute>; - -using RefGreaterUint8Workload = - RefComparisonWorkload<std::greater<uint8_t>, - DataType::QuantisedAsymm8, - GreaterQueueDescriptor, - StringMapping::RefGreaterWorkload_Execute>; + GreaterQueueDescriptor, + StringMapping::RefGreaterWorkload_Execute>; } // armnn diff --git a/src/backends/reference/workloads/RefElementwiseWorkload.cpp b/src/backends/reference/workloads/RefElementwiseWorkload.cpp index 356d7a0c16..6e6e1d5f21 100644 --- a/src/backends/reference/workloads/RefElementwiseWorkload.cpp +++ b/src/backends/reference/workloads/RefElementwiseWorkload.cpp @@ -14,14 +14,10 @@ namespace armnn { -template <typename Functor, - typename armnn::DataType DataType, - typename ParentDescriptor, - typename armnn::StringMapping::Id DebugString> -void RefElementwiseWorkload<Functor, DataType, ParentDescriptor, DebugString>::Execute() const +template <typename Functor, typename ParentDescriptor, typename armnn::StringMapping::Id DebugString> +void RefElementwiseWorkload<Functor, ParentDescriptor, DebugString>::Execute() const { ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, StringMapping::Instance().Get(DebugString)); - const TensorInfo& inputInfo0 = GetTensorInfo(m_Data.m_Inputs[0]); const TensorInfo& inputInfo1 = GetTensorInfo(m_Data.m_Inputs[1]); const TensorInfo& outputInfo = GetTensorInfo(m_Data.m_Outputs[0]); @@ -30,32 +26,46 @@ void RefElementwiseWorkload<Functor, DataType, ParentDescriptor, DebugString>::E const TensorShape& inShape1 = inputInfo1.GetShape(); const TensorShape& outShape = outputInfo.GetShape(); - switch(DataType) + switch(inputInfo0.GetDataType()) { case armnn::DataType::QuantisedAsymm8: { - std::vector<float> results(outputInfo.GetNumElements()); - ElementwiseFunction<Functor, float, float>(inShape0, - inShape1, - outShape, - Dequantize(GetInputTensorDataU8(0, m_Data), inputInfo0).data(), - Dequantize(GetInputTensorDataU8(1, m_Data), inputInfo1).data(), - results.data()); - Quantize(GetOutputTensorDataU8(0, m_Data), results.data(), outputInfo); + QASymm8Decoder decodeIterator0(GetInputTensorDataU8(0, m_Data), + inputInfo0.GetQuantizationScale(), + inputInfo0.GetQuantizationOffset()); + + QASymm8Decoder decodeIterator1(GetInputTensorDataU8(1, m_Data), + inputInfo1.GetQuantizationScale(), + inputInfo1.GetQuantizationOffset()); + + QASymm8Encoder encodeIterator0(GetOutputTensorDataU8(0, m_Data), + outputInfo.GetQuantizationScale(), + outputInfo.GetQuantizationOffset()); + + ElementwiseFunction<Functor, Decoder, Encoder>(inShape0, + inShape1, + outShape, + decodeIterator0, + decodeIterator1, + encodeIterator0); break; } case armnn::DataType::Float32: { - ElementwiseFunction<Functor, float, float>(inShape0, - inShape1, - outShape, - GetInputTensorDataFloat(0, m_Data), - GetInputTensorDataFloat(1, m_Data), - GetOutputTensorDataFloat(0, m_Data)); + FloatDecoder decodeIterator0(GetInputTensorDataFloat(0, m_Data)); + FloatDecoder decodeIterator1(GetInputTensorDataFloat(1, m_Data)); + FloatEncoder encodeIterator0(GetOutputTensorDataFloat(0, m_Data)); + + ElementwiseFunction<Functor, Decoder, Encoder>(inShape0, + inShape1, + outShape, + decodeIterator0, + decodeIterator1, + encodeIterator0); break; } default: - BOOST_ASSERT_MSG(false, "Unknown Data Type!"); + BOOST_ASSERT_MSG(false, "RefElementwiseWorkload: Not supported Data Type!"); break; } } @@ -63,62 +73,25 @@ void RefElementwiseWorkload<Functor, DataType, ParentDescriptor, DebugString>::E } template class armnn::RefElementwiseWorkload<std::plus<float>, - armnn::DataType::Float32, - armnn::AdditionQueueDescriptor, - armnn::StringMapping::RefAdditionWorkload_Execute>; - -template class armnn::RefElementwiseWorkload<std::plus<float>, - armnn::DataType::QuantisedAsymm8, - armnn::AdditionQueueDescriptor, - armnn::StringMapping::RefAdditionWorkload_Execute>; + armnn::AdditionQueueDescriptor, + armnn::StringMapping::RefAdditionWorkload_Execute>; template class armnn::RefElementwiseWorkload<std::minus<float>, - armnn::DataType::Float32, - armnn::SubtractionQueueDescriptor, - armnn::StringMapping::RefSubtractionWorkload_Execute>; - -template class armnn::RefElementwiseWorkload<std::minus<float>, - armnn::DataType::QuantisedAsymm8, - armnn::SubtractionQueueDescriptor, - armnn::StringMapping::RefSubtractionWorkload_Execute>; - -template class armnn::RefElementwiseWorkload<std::multiplies<float>, - armnn::DataType::Float32, - armnn::MultiplicationQueueDescriptor, - armnn::StringMapping::RefMultiplicationWorkload_Execute>; + armnn::SubtractionQueueDescriptor, + armnn::StringMapping::RefSubtractionWorkload_Execute>; template class armnn::RefElementwiseWorkload<std::multiplies<float>, - armnn::DataType::QuantisedAsymm8, - armnn::MultiplicationQueueDescriptor, - armnn::StringMapping::RefMultiplicationWorkload_Execute>; + armnn::MultiplicationQueueDescriptor, + armnn::StringMapping::RefMultiplicationWorkload_Execute>; template class armnn::RefElementwiseWorkload<std::divides<float>, - armnn::DataType::Float32, - armnn::DivisionQueueDescriptor, - armnn::StringMapping::RefDivisionWorkload_Execute>; - -template class armnn::RefElementwiseWorkload<std::divides<float>, - armnn::DataType::QuantisedAsymm8, - armnn::DivisionQueueDescriptor, - armnn::StringMapping::RefDivisionWorkload_Execute>; - -template class armnn::RefElementwiseWorkload<armnn::maximum<float>, - armnn::DataType::Float32, - armnn::MaximumQueueDescriptor, - armnn::StringMapping::RefMaximumWorkload_Execute>; + armnn::DivisionQueueDescriptor, + armnn::StringMapping::RefDivisionWorkload_Execute>; template class armnn::RefElementwiseWorkload<armnn::maximum<float>, - armnn::DataType::QuantisedAsymm8, - armnn::MaximumQueueDescriptor, - armnn::StringMapping::RefMaximumWorkload_Execute>; - - -template class armnn::RefElementwiseWorkload<armnn::minimum<float>, - armnn::DataType::Float32, - armnn::MinimumQueueDescriptor, - armnn::StringMapping::RefMinimumWorkload_Execute>; + armnn::MaximumQueueDescriptor, + armnn::StringMapping::RefMaximumWorkload_Execute>; template class armnn::RefElementwiseWorkload<armnn::minimum<float>, - armnn::DataType::QuantisedAsymm8, - armnn::MinimumQueueDescriptor, - armnn::StringMapping::RefMinimumWorkload_Execute>; + armnn::MinimumQueueDescriptor, + armnn::StringMapping::RefMinimumWorkload_Execute>;
\ No newline at end of file diff --git a/src/backends/reference/workloads/RefElementwiseWorkload.hpp b/src/backends/reference/workloads/RefElementwiseWorkload.hpp index 371904977a..81af19627e 100644 --- a/src/backends/reference/workloads/RefElementwiseWorkload.hpp +++ b/src/backends/reference/workloads/RefElementwiseWorkload.hpp @@ -15,90 +15,43 @@ namespace armnn { -template <typename Functor, - typename armnn::DataType DataType, - typename ParentDescriptor, - typename armnn::StringMapping::Id DebugString> -class RefElementwiseWorkload - : public TypedWorkload<ParentDescriptor, DataType> +template <typename Functor, typename ParentDescriptor, typename armnn::StringMapping::Id DebugString> +class RefElementwiseWorkload : public BaseWorkload<ParentDescriptor> { public: - - using TypedWorkload<ParentDescriptor, DataType>::m_Data; - using TypedWorkload<ParentDescriptor, DataType>::TypedWorkload; + using BaseWorkload<ParentDescriptor>::m_Data; + using BaseWorkload<ParentDescriptor>::BaseWorkload; void Execute() const override; }; -using RefAdditionFloat32Workload = - RefElementwiseWorkload<std::plus<float>, - DataType::Float32, - AdditionQueueDescriptor, - StringMapping::RefAdditionWorkload_Execute>; - -using RefAdditionUint8Workload = +using RefAdditionWorkload = RefElementwiseWorkload<std::plus<float>, - DataType::QuantisedAsymm8, AdditionQueueDescriptor, StringMapping::RefAdditionWorkload_Execute>; -using RefSubtractionFloat32Workload = - RefElementwiseWorkload<std::minus<float>, - DataType::Float32, - SubtractionQueueDescriptor, - StringMapping::RefSubtractionWorkload_Execute>; - -using RefSubtractionUint8Workload = +using RefSubtractionWorkload = RefElementwiseWorkload<std::minus<float>, - DataType::QuantisedAsymm8, SubtractionQueueDescriptor, StringMapping::RefSubtractionWorkload_Execute>; -using RefMultiplicationFloat32Workload = - RefElementwiseWorkload<std::multiplies<float>, - DataType::Float32, - MultiplicationQueueDescriptor, - StringMapping::RefMultiplicationWorkload_Execute>; - -using RefMultiplicationUint8Workload = +using RefMultiplicationWorkload = RefElementwiseWorkload<std::multiplies<float>, - DataType::QuantisedAsymm8, MultiplicationQueueDescriptor, StringMapping::RefMultiplicationWorkload_Execute>; -using RefDivisionFloat32Workload = - RefElementwiseWorkload<std::divides<float>, - DataType::Float32, - DivisionQueueDescriptor, - StringMapping::RefDivisionWorkload_Execute>; - -using RefDivisionUint8Workload = +using RefDivisionWorkload = RefElementwiseWorkload<std::divides<float>, - DataType::QuantisedAsymm8, DivisionQueueDescriptor, StringMapping::RefDivisionWorkload_Execute>; -using RefMaximumFloat32Workload = - RefElementwiseWorkload<armnn::maximum<float>, - DataType::Float32, - MaximumQueueDescriptor, - StringMapping::RefMaximumWorkload_Execute>; - -using RefMaximumUint8Workload = +using RefMaximumWorkload = RefElementwiseWorkload<armnn::maximum<float>, - DataType::QuantisedAsymm8, MaximumQueueDescriptor, StringMapping::RefMaximumWorkload_Execute>; -using RefMinimumFloat32Workload = - RefElementwiseWorkload<minimum<float>, - DataType::Float32, - MinimumQueueDescriptor, - StringMapping::RefMinimumWorkload_Execute>; - -using RefMinimumUint8Workload = - RefElementwiseWorkload<minimum<float>, - DataType::QuantisedAsymm8, +using RefMinimumWorkload = + RefElementwiseWorkload<armnn::minimum<float>, MinimumQueueDescriptor, StringMapping::RefMinimumWorkload_Execute>; } // armnn |