diff options
Diffstat (limited to 'src/backends/reference/workloads')
8 files changed, 188 insertions, 56 deletions
diff --git a/src/backends/reference/workloads/CMakeLists.txt b/src/backends/reference/workloads/CMakeLists.txt index f95fda08d1..57e89fa456 100644 --- a/src/backends/reference/workloads/CMakeLists.txt +++ b/src/backends/reference/workloads/CMakeLists.txt @@ -40,6 +40,8 @@ list(APPEND armnnRefBackendWorkloads_sources RefBatchToSpaceNdFloat32Workload.hpp RefBatchToSpaceNdUint8Workload.cpp RefBatchToSpaceNdUint8Workload.hpp + RefComparisonWorkload.cpp + RefComparisonWorkload.hpp RefConstantWorkload.cpp RefConstantWorkload.hpp RefConvertFp16ToFp32Workload.cpp diff --git a/src/backends/reference/workloads/ElementwiseFunction.cpp b/src/backends/reference/workloads/ElementwiseFunction.cpp index cb8aa7089c..c8c25ef9e9 100644 --- a/src/backends/reference/workloads/ElementwiseFunction.cpp +++ b/src/backends/reference/workloads/ElementwiseFunction.cpp @@ -13,24 +13,26 @@ namespace armnn { -template <typename Functor> -ElementwiseFunction<Functor>::ElementwiseFunction(const TensorShape& inShape0, - const TensorShape& inShape1, - const TensorShape& outShape, - const float* inData0, - const float* inData1, - float* outData) +template <typename Functor, typename dataTypeInput, typename dataTypeOutput> +ElementwiseFunction<Functor, dataTypeInput, dataTypeOutput>::ElementwiseFunction(const TensorShape& inShape0, + const TensorShape& inShape1, + const TensorShape& outShape, + const dataTypeInput* inData0, + const dataTypeInput* inData1, + dataTypeOutput* outData) { BroadcastLoop(inShape0, inShape1, outShape).Unroll(Functor(), 0, inData0, inData1, outData); } } //namespace armnn -template struct armnn::ElementwiseFunction<std::plus<float>>; -template struct armnn::ElementwiseFunction<std::minus<float>>; -template struct armnn::ElementwiseFunction<std::multiplies<float>>; -template struct armnn::ElementwiseFunction<std::divides<float>>; -template struct armnn::ElementwiseFunction<armnn::maximum<float>>; -template struct armnn::ElementwiseFunction<armnn::minimum<float>>; -template struct armnn::ElementwiseFunction<std::equal_to<float>>; -template struct armnn::ElementwiseFunction<std::greater<float>>; +template struct armnn::ElementwiseFunction<std::plus<float>, float, float>; +template struct armnn::ElementwiseFunction<std::minus<float>, float, float>; +template struct armnn::ElementwiseFunction<std::multiplies<float>, float, float>; +template struct armnn::ElementwiseFunction<std::divides<float>, float, float>; +template struct armnn::ElementwiseFunction<armnn::maximum<float>, float, float>; +template struct armnn::ElementwiseFunction<armnn::minimum<float>, float, float>; +template struct armnn::ElementwiseFunction<std::equal_to<float>, float ,uint8_t>; +template struct armnn::ElementwiseFunction<std::equal_to<uint8_t>, uint8_t, uint8_t>; +template struct armnn::ElementwiseFunction<std::greater<float>, float, uint8_t>; +template struct armnn::ElementwiseFunction<std::greater<uint8_t>, uint8_t, uint8_t>; diff --git a/src/backends/reference/workloads/ElementwiseFunction.hpp b/src/backends/reference/workloads/ElementwiseFunction.hpp index 0ac136466c..8099f3279a 100644 --- a/src/backends/reference/workloads/ElementwiseFunction.hpp +++ b/src/backends/reference/workloads/ElementwiseFunction.hpp @@ -10,15 +10,15 @@ namespace armnn { -template <typename Functor> +template <typename Functor, typename dataTypeInput, typename dataTypeOutput> struct ElementwiseFunction { ElementwiseFunction(const TensorShape& inShape0, const TensorShape& inShape1, const TensorShape& outShape, - const float* inData0, - const float* inData1, - float* outData); + const dataTypeInput* inData0, + const dataTypeInput* inData1, + dataTypeOutput* outData); }; } //namespace armnn diff --git a/src/backends/reference/workloads/RefComparisonWorkload.cpp b/src/backends/reference/workloads/RefComparisonWorkload.cpp new file mode 100644 index 0000000000..fe517ff51a --- /dev/null +++ b/src/backends/reference/workloads/RefComparisonWorkload.cpp @@ -0,0 +1,65 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#include "RefComparisonWorkload.hpp" +#include "ElementwiseFunction.hpp" +#include "RefWorkloadUtils.hpp" +#include "Profiling.hpp" +#include <vector> + +namespace armnn { + +template<typename ParentDescriptor, typename Functor> +void RefFloat32ComparisonWorkload<ParentDescriptor, Functor>::ExecuteImpl(const char* debugString) const +{ + ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, debugString); + + auto data = BaseFloat32ComparisonWorkload<ParentDescriptor>::GetData(); + const TensorShape& inShape0 = GetTensorInfo(data.m_Inputs[0]).GetShape(); + const TensorShape& inShape1 = GetTensorInfo(data.m_Inputs[1]).GetShape(); + const TensorShape& outputShape = GetTensorInfo(data.m_Outputs[0]).GetShape(); + + const float* inData0 = GetInputTensorDataFloat(0, data); + const float* inData1 = GetInputTensorDataFloat(1, data); + uint8_t* outData = GetOutputTensorData<uint8_t>(0, data); + + ElementwiseFunction<Functor, float, uint8_t>(inShape0, + inShape1, + outputShape, + inData0, + inData1, + outData); + +} + +template<typename ParentDescriptor, typename Functor> +void RefUint8ComparisonWorkload<ParentDescriptor, Functor>::ExecuteImpl(const char* debugString) const +{ + ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, debugString); + + auto data = BaseUint8ComparisonWorkload<ParentDescriptor>::GetData(); + const TensorShape& inputInfo0 = GetTensorInfo(data.m_Inputs[0]).GetShape(); + const TensorShape& inputInfo1 = GetTensorInfo(data.m_Inputs[1]).GetShape(); + const TensorShape& outputShape = GetTensorInfo(data.m_Outputs[0]).GetShape(); + + const uint8_t* inData0 = GetInputTensorData<uint8_t>(0, data); + const uint8_t* inData1 = GetInputTensorData<uint8_t>(1, data); + uint8_t* outData = GetOutputTensorData<uint8_t>(0, data); + + ElementwiseFunction<Functor, uint8_t, uint8_t>(inputInfo0, + inputInfo1, + outputShape, + inData0, + inData1, + outData); +} + +} + +template class armnn::RefFloat32ComparisonWorkload<armnn::EqualQueueDescriptor, std::equal_to<float>>; +template class armnn::RefUint8ComparisonWorkload<armnn::EqualQueueDescriptor, std::equal_to<uint8_t>>; + +template class armnn::RefFloat32ComparisonWorkload<armnn::GreaterQueueDescriptor, std::greater<float>>; +template class armnn::RefUint8ComparisonWorkload<armnn::GreaterQueueDescriptor, std::greater<uint8_t>>; diff --git a/src/backends/reference/workloads/RefComparisonWorkload.hpp b/src/backends/reference/workloads/RefComparisonWorkload.hpp new file mode 100644 index 0000000000..524d20625a --- /dev/null +++ b/src/backends/reference/workloads/RefComparisonWorkload.hpp @@ -0,0 +1,92 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#pragma once + +#include <armnn/Types.hpp> +#include <backendsCommon/Workload.hpp> +#include <backendsCommon/WorkloadData.hpp> +#include "StringMapping.hpp" + +namespace armnn +{ + +template <typename Functor, + typename armnn::DataType DataType, + typename ParentDescriptor, + typename armnn::StringMapping::Id DebugString> +class RefComparisonWorkload +{ + // Needs specialization. The default is empty on purpose. +}; + +template <typename ParentDescriptor, typename Functor> +class RefFloat32ComparisonWorkload : public BaseFloat32ComparisonWorkload<ParentDescriptor> +{ +public: + using BaseFloat32ComparisonWorkload<ParentDescriptor>::BaseFloat32ComparisonWorkload; + void ExecuteImpl(const char * debugString) const; +}; + +template <typename Functor, typename ParentDescriptor, typename armnn::StringMapping::Id DebugString> +class RefComparisonWorkload<Functor, armnn::DataType::Float32, ParentDescriptor, DebugString> + : public RefFloat32ComparisonWorkload<ParentDescriptor, Functor> +{ +public: + using RefFloat32ComparisonWorkload<ParentDescriptor, Functor>::RefFloat32ComparisonWorkload; + + virtual void Execute() const override + { + using Parent = RefFloat32ComparisonWorkload<ParentDescriptor, Functor>; + Parent::ExecuteImpl(StringMapping::Instance().Get(DebugString)); + } +}; + +template <typename ParentDescriptor, typename Functor> +class RefUint8ComparisonWorkload : public BaseUint8ComparisonWorkload<ParentDescriptor> +{ +public: + using BaseUint8ComparisonWorkload<ParentDescriptor>::BaseUint8ComparisonWorkload; + void ExecuteImpl(const char * debugString) const; +}; + +template <typename Functor, typename ParentDescriptor, typename armnn::StringMapping::Id DebugString> +class RefComparisonWorkload<Functor, armnn::DataType::QuantisedAsymm8, ParentDescriptor, DebugString> + : public RefUint8ComparisonWorkload<ParentDescriptor, Functor> +{ +public: + using RefUint8ComparisonWorkload<ParentDescriptor, Functor>::RefUint8ComparisonWorkload; + + virtual void Execute() const override + { + using Parent = RefUint8ComparisonWorkload<ParentDescriptor, Functor>; + Parent::ExecuteImpl(StringMapping::Instance().Get(DebugString)); + } +}; + +using RefEqualFloat32Workload = + RefComparisonWorkload<std::equal_to<float>, + DataType::Float32, + EqualQueueDescriptor, + StringMapping::RefEqualWorkload_Execute>; + +using RefEqualUint8Workload = + RefComparisonWorkload<std::equal_to<uint8_t>, + DataType::QuantisedAsymm8, + EqualQueueDescriptor, + StringMapping::RefEqualWorkload_Execute>; + +using RefGreaterFloat32Workload = + RefComparisonWorkload<std::greater<float>, + DataType::Float32, + GreaterQueueDescriptor, + StringMapping::RefGreaterWorkload_Execute>; + +using RefGreaterUint8Workload = + RefComparisonWorkload<std::greater<uint8_t>, + DataType::QuantisedAsymm8, + GreaterQueueDescriptor, + StringMapping::RefGreaterWorkload_Execute>; +} // armnn diff --git a/src/backends/reference/workloads/RefElementwiseWorkload.cpp b/src/backends/reference/workloads/RefElementwiseWorkload.cpp index 13d6e70a96..c9b93c8524 100644 --- a/src/backends/reference/workloads/RefElementwiseWorkload.cpp +++ b/src/backends/reference/workloads/RefElementwiseWorkload.cpp @@ -26,7 +26,7 @@ void BaseFloat32ElementwiseWorkload<ParentDescriptor, Functor>::ExecuteImpl(cons const float* inData1 = GetInputTensorDataFloat(1, data); float* outData = GetOutputTensorDataFloat(0, data); - ElementwiseFunction<Functor>(inShape0, inShape1, outShape, inData0, inData1, outData); + ElementwiseFunction<Functor, float, float>(inShape0, inShape1, outShape, inData0, inData1, outData); } template <typename ParentDescriptor, typename Functor> @@ -44,12 +44,12 @@ void BaseUint8ElementwiseWorkload<ParentDescriptor, Functor>::ExecuteImpl(const std::vector<float> results(outputInfo.GetNumElements()); - ElementwiseFunction<Functor>(inputInfo0.GetShape(), - inputInfo1.GetShape(), - outputInfo.GetShape(), - dequant0.data(), - dequant1.data(), - results.data()); + ElementwiseFunction<Functor, float, float>(inputInfo0.GetShape(), + inputInfo1.GetShape(), + outputInfo.GetShape(), + dequant0.data(), + dequant1.data(), + results.data()); Quantize(GetOutputTensorDataU8(0, data), results.data(), outputInfo); } @@ -73,9 +73,3 @@ template class armnn::BaseUint8ElementwiseWorkload<armnn::MaximumQueueDescriptor template class armnn::BaseFloat32ElementwiseWorkload<armnn::MinimumQueueDescriptor, armnn::minimum<float>>; template class armnn::BaseUint8ElementwiseWorkload<armnn::MinimumQueueDescriptor, armnn::minimum<float>>; - -template class armnn::BaseFloat32ElementwiseWorkload<armnn::EqualQueueDescriptor, std::equal_to<float>>; -template class armnn::BaseUint8ElementwiseWorkload<armnn::EqualQueueDescriptor, std::equal_to<float>>; - -template class armnn::BaseFloat32ElementwiseWorkload<armnn::GreaterQueueDescriptor, std::greater<float>>; -template class armnn::BaseUint8ElementwiseWorkload<armnn::GreaterQueueDescriptor, std::greater<float>>; diff --git a/src/backends/reference/workloads/RefElementwiseWorkload.hpp b/src/backends/reference/workloads/RefElementwiseWorkload.hpp index 6dd6865f53..a5ff376673 100644 --- a/src/backends/reference/workloads/RefElementwiseWorkload.hpp +++ b/src/backends/reference/workloads/RefElementwiseWorkload.hpp @@ -144,28 +144,4 @@ using RefMinimumUint8Workload = DataType::QuantisedAsymm8, MinimumQueueDescriptor, StringMapping::RefMinimumWorkload_Execute>; - -using RefEqualFloat32Workload = - RefElementwiseWorkload<std::equal_to<float>, - DataType::Float32, - EqualQueueDescriptor, - StringMapping::RefEqualWorkload_Execute>; - -using RefEqualUint8Workload = - RefElementwiseWorkload<std::equal_to<float>, - DataType::QuantisedAsymm8, - EqualQueueDescriptor, - StringMapping::RefEqualWorkload_Execute>; - -using RefGreaterFloat32Workload = - RefElementwiseWorkload<std::greater<float>, - DataType::Float32, - GreaterQueueDescriptor, - StringMapping::RefGreaterWorkload_Execute>; - -using RefGreaterUint8Workload = - RefElementwiseWorkload<std::greater<float>, - DataType::QuantisedAsymm8, - GreaterQueueDescriptor, - StringMapping::RefGreaterWorkload_Execute>; } // armnn diff --git a/src/backends/reference/workloads/RefWorkloads.hpp b/src/backends/reference/workloads/RefWorkloads.hpp index 1cbceb366b..d9f4dbb342 100644 --- a/src/backends/reference/workloads/RefWorkloads.hpp +++ b/src/backends/reference/workloads/RefWorkloads.hpp @@ -60,3 +60,4 @@ #include "RefBatchToSpaceNdFloat32Workload.hpp" #include "RefDebugWorkload.hpp" #include "RefRsqrtFloat32Workload.hpp" +#include "RefComparisonWorkload.hpp" |