From d57415d9a2117da9cc5c58f8b5e39ba7455417d1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=89anna=20=C3=93=20Cath=C3=A1in?= Date: Wed, 28 Nov 2018 16:24:38 +0000 Subject: IVGCVSW-2202 Refactoring Arithmetic* names to Elementwise* names for workloads and workload functions Change-Id: I6f3fce12a55f7d38ceafcdfcd6b5181bf56e2c09 --- src/backends/README.md | 2 +- src/backends/cl/test/ClCreateWorkloadTests.cpp | 24 ++-- src/backends/neon/test/NeonCreateWorkloadTests.cpp | 16 +-- src/backends/reference/backend.mk | 4 +- .../reference/test/RefCreateWorkloadTests.cpp | 69 ++++++------ .../reference/workloads/ArithmeticFunction.cpp | 29 ----- .../reference/workloads/ArithmeticFunction.hpp | 24 ---- src/backends/reference/workloads/CMakeLists.txt | 8 +- .../reference/workloads/ElementwiseFunction.cpp | 29 +++++ .../reference/workloads/ElementwiseFunction.hpp | 24 ++++ .../reference/workloads/RefArithmeticWorkload.cpp | 69 ------------ .../reference/workloads/RefArithmeticWorkload.hpp | 122 --------------------- .../reference/workloads/RefElementwiseWorkload.cpp | 69 ++++++++++++ .../reference/workloads/RefElementwiseWorkload.hpp | 122 +++++++++++++++++++++ src/backends/reference/workloads/RefWorkloads.hpp | 4 +- 15 files changed, 308 insertions(+), 307 deletions(-) delete mode 100644 src/backends/reference/workloads/ArithmeticFunction.cpp delete mode 100644 src/backends/reference/workloads/ArithmeticFunction.hpp create mode 100644 src/backends/reference/workloads/ElementwiseFunction.cpp create mode 100644 src/backends/reference/workloads/ElementwiseFunction.hpp delete mode 100644 src/backends/reference/workloads/RefArithmeticWorkload.cpp delete mode 100644 src/backends/reference/workloads/RefArithmeticWorkload.hpp create mode 100644 src/backends/reference/workloads/RefElementwiseWorkload.cpp create mode 100644 src/backends/reference/workloads/RefElementwiseWorkload.hpp (limited to 'src/backends') diff --git a/src/backends/README.md b/src/backends/README.md index ddd1bb6b92..60e4d0baa7 100644 --- a/src/backends/README.md +++ b/src/backends/README.md @@ -68,7 +68,7 @@ BACKEND_SOURCES := \ RefLayerSupport.cpp \ RefWorkloadFactory.cpp \ workloads/Activation.cpp \ - workloads/ArithmeticFunction.cpp \ + workloads/ElementwiseFunction.cpp \ workloads/Broadcast.cpp \ ... diff --git a/src/backends/cl/test/ClCreateWorkloadTests.cpp b/src/backends/cl/test/ClCreateWorkloadTests.cpp index 8cef9d78b1..c5f685dc94 100644 --- a/src/backends/cl/test/ClCreateWorkloadTests.cpp +++ b/src/backends/cl/test/ClCreateWorkloadTests.cpp @@ -55,15 +55,15 @@ template -static void ClCreateArithmethicWorkloadTest() +static void ClCreateElementwiseWorkloadTest() { Graph graph; ClWorkloadFactory factory = ClWorkloadFactoryHelper::GetFactory(ClWorkloadFactoryHelper::GetMemoryManager()); - auto workload = CreateArithmeticWorkloadTest(factory, graph); + auto workload = CreateElementwiseWorkloadTest(factory, graph); - // Checks that inputs/outputs are as we expect them (see definition of CreateArithmeticWorkloadTest). + // Checks that inputs/outputs are as we expect them (see definition of CreateElementwiseWorkloadTest). DescriptorType queueDescriptor = workload->GetData(); auto inputHandle1 = boost::polymorphic_downcast(queueDescriptor.m_Inputs[0]); auto inputHandle2 = boost::polymorphic_downcast(queueDescriptor.m_Inputs[1]); @@ -75,7 +75,7 @@ static void ClCreateArithmethicWorkloadTest() BOOST_AUTO_TEST_CASE(CreateAdditionFloatWorkload) { - ClCreateArithmethicWorkloadTest(); @@ -83,7 +83,7 @@ BOOST_AUTO_TEST_CASE(CreateAdditionFloatWorkload) BOOST_AUTO_TEST_CASE(CreateAdditionFloat16Workload) { - ClCreateArithmethicWorkloadTest(); @@ -91,7 +91,7 @@ BOOST_AUTO_TEST_CASE(CreateAdditionFloat16Workload) BOOST_AUTO_TEST_CASE(CreateSubtractionFloatWorkload) { - ClCreateArithmethicWorkloadTest(); @@ -99,7 +99,7 @@ BOOST_AUTO_TEST_CASE(CreateSubtractionFloatWorkload) BOOST_AUTO_TEST_CASE(CreateSubtractionFloat16Workload) { - ClCreateArithmethicWorkloadTest(); @@ -107,7 +107,7 @@ BOOST_AUTO_TEST_CASE(CreateSubtractionFloat16Workload) BOOST_AUTO_TEST_CASE(CreateMultiplicationFloatWorkloadTest) { - ClCreateArithmethicWorkloadTest(); @@ -115,7 +115,7 @@ BOOST_AUTO_TEST_CASE(CreateMultiplicationFloatWorkloadTest) BOOST_AUTO_TEST_CASE(CreateMultiplicationFloat16WorkloadTest) { - ClCreateArithmethicWorkloadTest(); @@ -123,7 +123,7 @@ BOOST_AUTO_TEST_CASE(CreateMultiplicationFloat16WorkloadTest) BOOST_AUTO_TEST_CASE(CreateMultiplicationUint8WorkloadTest) { - ClCreateArithmethicWorkloadTest(); @@ -131,7 +131,7 @@ BOOST_AUTO_TEST_CASE(CreateMultiplicationUint8WorkloadTest) BOOST_AUTO_TEST_CASE(CreateDivisionFloatWorkloadTest) { - ClCreateArithmethicWorkloadTest(); @@ -139,7 +139,7 @@ BOOST_AUTO_TEST_CASE(CreateDivisionFloatWorkloadTest) BOOST_AUTO_TEST_CASE(CreateDivisionFloat16WorkloadTest) { - ClCreateArithmethicWorkloadTest(); diff --git a/src/backends/neon/test/NeonCreateWorkloadTests.cpp b/src/backends/neon/test/NeonCreateWorkloadTests.cpp index 120125311e..dc6ec16e49 100644 --- a/src/backends/neon/test/NeonCreateWorkloadTests.cpp +++ b/src/backends/neon/test/NeonCreateWorkloadTests.cpp @@ -87,13 +87,13 @@ template -static void NeonCreateArithmethicWorkloadTest() +static void NeonCreateElementwiseWorkloadTest() { Graph graph; NeonWorkloadFactory factory = NeonWorkloadFactoryHelper::GetFactory(NeonWorkloadFactoryHelper::GetMemoryManager()); - auto workload = CreateArithmeticWorkloadTest(factory, graph); + auto workload = CreateElementwiseWorkloadTest(factory, graph); DescriptorType queueDescriptor = workload->GetData(); auto inputHandle1 = boost::polymorphic_downcast(queueDescriptor.m_Inputs[0]); @@ -107,7 +107,7 @@ static void NeonCreateArithmethicWorkloadTest() #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC BOOST_AUTO_TEST_CASE(CreateAdditionFloat16Workload) { - NeonCreateArithmethicWorkloadTest(); @@ -116,7 +116,7 @@ BOOST_AUTO_TEST_CASE(CreateAdditionFloat16Workload) BOOST_AUTO_TEST_CASE(CreateAdditionFloatWorkload) { - NeonCreateArithmethicWorkloadTest(); @@ -125,7 +125,7 @@ BOOST_AUTO_TEST_CASE(CreateAdditionFloatWorkload) #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC BOOST_AUTO_TEST_CASE(CreateSubtractionFloat16Workload) { - NeonCreateArithmethicWorkloadTest(); @@ -134,7 +134,7 @@ BOOST_AUTO_TEST_CASE(CreateSubtractionFloat16Workload) BOOST_AUTO_TEST_CASE(CreateSubtractionFloatWorkload) { - NeonCreateArithmethicWorkloadTest(); @@ -143,7 +143,7 @@ BOOST_AUTO_TEST_CASE(CreateSubtractionFloatWorkload) #ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC BOOST_AUTO_TEST_CASE(CreateMultiplicationFloat16Workload) { - NeonCreateArithmethicWorkloadTest(); @@ -152,7 +152,7 @@ BOOST_AUTO_TEST_CASE(CreateMultiplicationFloat16Workload) BOOST_AUTO_TEST_CASE(CreateMultiplicationFloatWorkload) { - NeonCreateArithmethicWorkloadTest(); diff --git a/src/backends/reference/backend.mk b/src/backends/reference/backend.mk index 7162d4a81e..66675bd2f9 100644 --- a/src/backends/reference/backend.mk +++ b/src/backends/reference/backend.mk @@ -12,17 +12,16 @@ BACKEND_SOURCES := \ RefLayerSupport.cpp \ RefWorkloadFactory.cpp \ workloads/Activation.cpp \ - workloads/ArithmeticFunction.cpp \ workloads/BatchToSpaceNd.cpp \ workloads/Broadcast.cpp \ workloads/ConvImpl.cpp \ + workloads/ElementwiseFunction.cpp \ workloads/FullyConnected.cpp \ workloads/Mean.cpp \ workloads/Pad.cpp \ workloads/Pooling2d.cpp \ workloads/RefActivationFloat32Workload.cpp \ workloads/RefActivationUint8Workload.cpp \ - workloads/RefArithmeticWorkload.cpp \ workloads/RefBaseConstantWorkload.cpp \ workloads/RefBatchNormalizationFloat32Workload.cpp \ workloads/RefBatchNormalizationUint8Workload.cpp \ @@ -36,6 +35,7 @@ BACKEND_SOURCES := \ workloads/RefConvolution2dUint8Workload.cpp \ workloads/RefDepthwiseConvolution2dFloat32Workload.cpp \ workloads/RefDepthwiseConvolution2dUint8Workload.cpp \ + workloads/RefElementwiseWorkload.cpp \ workloads/RefFakeQuantizationFloat32Workload.cpp \ workloads/RefFloorFloat32Workload.cpp \ workloads/RefFullyConnectedFloat32Workload.cpp \ diff --git a/src/backends/reference/test/RefCreateWorkloadTests.cpp b/src/backends/reference/test/RefCreateWorkloadTests.cpp index 47f9d0ef4e..8621122925 100644 --- a/src/backends/reference/test/RefCreateWorkloadTests.cpp +++ b/src/backends/reference/test/RefCreateWorkloadTests.cpp @@ -67,11 +67,12 @@ template -static void RefCreateArithmethicWorkloadTest() +static void RefCreateElementwiseWorkloadTest() { Graph graph; RefWorkloadFactory factory; - auto workload = CreateArithmeticWorkloadTest(factory, graph); + auto workload = CreateElementwiseWorkloadTest( + factory, graph); CheckInputsOutput(std::move(workload), TensorInfo({ 2, 3 }, DataType), @@ -81,66 +82,66 @@ static void RefCreateArithmethicWorkloadTest() BOOST_AUTO_TEST_CASE(CreateAdditionFloatWorkload) { - RefCreateArithmethicWorkloadTest(); + RefCreateElementwiseWorkloadTest(); } BOOST_AUTO_TEST_CASE(CreateAdditionUint8Workload) { - RefCreateArithmethicWorkloadTest(); + RefCreateElementwiseWorkloadTest(); } BOOST_AUTO_TEST_CASE(CreateSubtractionFloatWorkload) { - RefCreateArithmethicWorkloadTest(); + RefCreateElementwiseWorkloadTest(); } BOOST_AUTO_TEST_CASE(CreateSubtractionUint8Workload) { - RefCreateArithmethicWorkloadTest(); + RefCreateElementwiseWorkloadTest(); } BOOST_AUTO_TEST_CASE(CreateMultiplicationFloatWorkload) { - RefCreateArithmethicWorkloadTest(); + RefCreateElementwiseWorkloadTest(); } BOOST_AUTO_TEST_CASE(CreateMultiplicationUint8Workload) { - RefCreateArithmethicWorkloadTest(); + RefCreateElementwiseWorkloadTest(); } BOOST_AUTO_TEST_CASE(CreateDivisionFloatWorkload) { - RefCreateArithmethicWorkloadTest(); + RefCreateElementwiseWorkloadTest(); } BOOST_AUTO_TEST_CASE(CreateDivisionUint8Workload) { - RefCreateArithmethicWorkloadTest(); + RefCreateElementwiseWorkloadTest(); } template diff --git a/src/backends/reference/workloads/ArithmeticFunction.cpp b/src/backends/reference/workloads/ArithmeticFunction.cpp deleted file mode 100644 index fede138253..0000000000 --- a/src/backends/reference/workloads/ArithmeticFunction.cpp +++ /dev/null @@ -1,29 +0,0 @@ -// -// Copyright © 2017 Arm Ltd. All rights reserved. -// SPDX-License-Identifier: MIT -// - -#include "ArithmeticFunction.hpp" -#include "Broadcast.hpp" -#include - -namespace armnn -{ - -template -ArithmeticFunction::ArithmeticFunction(const TensorShape& inShape0, - const TensorShape& inShape1, - const TensorShape& outShape, - const float* inData0, - const float* inData1, - float* outData) -{ - BroadcastLoop(inShape0, inShape1, outShape).Unroll(Functor(), 0, inData0, inData1, outData); -} - -} //namespace armnn - -template struct armnn::ArithmeticFunction>; -template struct armnn::ArithmeticFunction>; -template struct armnn::ArithmeticFunction>; -template struct armnn::ArithmeticFunction>; diff --git a/src/backends/reference/workloads/ArithmeticFunction.hpp b/src/backends/reference/workloads/ArithmeticFunction.hpp deleted file mode 100644 index eafb6444f6..0000000000 --- a/src/backends/reference/workloads/ArithmeticFunction.hpp +++ /dev/null @@ -1,24 +0,0 @@ -// -// Copyright © 2017 Arm Ltd. All rights reserved. -// SPDX-License-Identifier: MIT -// - -#pragma once - -#include - -namespace armnn -{ - -template -struct ArithmeticFunction -{ - ArithmeticFunction(const TensorShape& inShape0, - const TensorShape& inShape1, - const TensorShape& outShape, - const float* inData0, - const float* inData1, - float* outData); -}; - -} //namespace armnn diff --git a/src/backends/reference/workloads/CMakeLists.txt b/src/backends/reference/workloads/CMakeLists.txt index 2d9ad926f7..86c5f908b9 100644 --- a/src/backends/reference/workloads/CMakeLists.txt +++ b/src/backends/reference/workloads/CMakeLists.txt @@ -6,8 +6,6 @@ list(APPEND armnnRefBackendWorkloads_sources Activation.cpp Activation.hpp - ArithmeticFunction.cpp - ArithmeticFunction.hpp BatchNormImpl.hpp BatchToSpaceNd.cpp BatchToSpaceNd.hpp @@ -15,6 +13,8 @@ list(APPEND armnnRefBackendWorkloads_sources Broadcast.hpp ConvImpl.cpp ConvImpl.hpp + ElementwiseFunction.cpp + ElementwiseFunction.hpp FullyConnected.cpp FullyConnected.hpp Merger.hpp @@ -26,8 +26,6 @@ list(APPEND armnnRefBackendWorkloads_sources RefActivationFloat32Workload.hpp RefActivationUint8Workload.cpp RefActivationUint8Workload.hpp - RefArithmeticWorkload.cpp - RefArithmeticWorkload.hpp RefBaseConstantWorkload.cpp RefBaseConstantWorkload.hpp RefBatchNormalizationFloat32Workload.cpp @@ -50,6 +48,8 @@ list(APPEND armnnRefBackendWorkloads_sources RefConvolution2dFloat32Workload.hpp RefConvolution2dUint8Workload.cpp RefConvolution2dUint8Workload.hpp + RefElementwiseWorkload.cpp + RefElementwiseWorkload.hpp RefDepthwiseConvolution2dFloat32Workload.cpp RefDepthwiseConvolution2dFloat32Workload.hpp RefDepthwiseConvolution2dUint8Workload.cpp diff --git a/src/backends/reference/workloads/ElementwiseFunction.cpp b/src/backends/reference/workloads/ElementwiseFunction.cpp new file mode 100644 index 0000000000..bea3d2fb89 --- /dev/null +++ b/src/backends/reference/workloads/ElementwiseFunction.cpp @@ -0,0 +1,29 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#include "ElementwiseFunction.hpp" +#include "Broadcast.hpp" +#include + +namespace armnn +{ + +template +ElementwiseFunction::ElementwiseFunction(const TensorShape& inShape0, + const TensorShape& inShape1, + const TensorShape& outShape, + const float* inData0, + const float* inData1, + float* outData) +{ + BroadcastLoop(inShape0, inShape1, outShape).Unroll(Functor(), 0, inData0, inData1, outData); +} + +} //namespace armnn + +template struct armnn::ElementwiseFunction>; +template struct armnn::ElementwiseFunction>; +template struct armnn::ElementwiseFunction>; +template struct armnn::ElementwiseFunction>; diff --git a/src/backends/reference/workloads/ElementwiseFunction.hpp b/src/backends/reference/workloads/ElementwiseFunction.hpp new file mode 100644 index 0000000000..5011616c0c --- /dev/null +++ b/src/backends/reference/workloads/ElementwiseFunction.hpp @@ -0,0 +1,24 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#pragma once + +#include + +namespace armnn +{ + +template +struct ElementwiseFunction +{ + ElementwiseFunction(const TensorShape& inShape0, + const TensorShape& inShape1, + const TensorShape& outShape, + const float* inData0, + const float* inData1, + float* outData); +}; + +} //namespace armnn diff --git a/src/backends/reference/workloads/RefArithmeticWorkload.cpp b/src/backends/reference/workloads/RefArithmeticWorkload.cpp deleted file mode 100644 index 6c39fa1186..0000000000 --- a/src/backends/reference/workloads/RefArithmeticWorkload.cpp +++ /dev/null @@ -1,69 +0,0 @@ -// -// Copyright © 2017 Arm Ltd. All rights reserved. -// SPDX-License-Identifier: MIT -// - -#include "RefArithmeticWorkload.hpp" -#include "ArithmeticFunction.hpp" -#include "RefWorkloadUtils.hpp" -#include "Profiling.hpp" -#include - -namespace armnn -{ - -template -void BaseFloat32ArithmeticWorkload::ExecuteImpl(const char * debugString) const -{ - ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, debugString); - - auto data = Float32Workload::GetData(); - const TensorShape& inShape0 = GetTensorInfo(data.m_Inputs[0]).GetShape(); - const TensorShape& inShape1 = GetTensorInfo(data.m_Inputs[1]).GetShape(); - const TensorShape& outShape = GetTensorInfo(data.m_Outputs[0]).GetShape(); - - const float* inData0 = GetInputTensorDataFloat(0, data); - const float* inData1 = GetInputTensorDataFloat(1, data); - float* outData = GetOutputTensorDataFloat(0, data); - - ArithmeticFunction(inShape0, inShape1, outShape, inData0, inData1, outData); -} - -template -void BaseUint8ArithmeticWorkload::ExecuteImpl(const char * debugString) const -{ - ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, debugString); - - auto data = Uint8Workload::GetData(); - const TensorInfo& inputInfo0 = GetTensorInfo(data.m_Inputs[0]); - const TensorInfo& inputInfo1 = GetTensorInfo(data.m_Inputs[1]); - const TensorInfo& outputInfo = GetTensorInfo(data.m_Outputs[0]); - - auto dequant0 = Dequantize(GetInputTensorDataU8(0, data), inputInfo0); - auto dequant1 = Dequantize(GetInputTensorDataU8(1, data), inputInfo1); - - std::vector results(outputInfo.GetNumElements()); - - ArithmeticFunction(inputInfo0.GetShape(), - inputInfo1.GetShape(), - outputInfo.GetShape(), - dequant0.data(), - dequant1.data(), - results.data()); - - Quantize(GetOutputTensorDataU8(0, data), results.data(), outputInfo); -} - -} - -template class armnn::BaseFloat32ArithmeticWorkload>; -template class armnn::BaseUint8ArithmeticWorkload>; - -template class armnn::BaseFloat32ArithmeticWorkload>; -template class armnn::BaseUint8ArithmeticWorkload>; - -template class armnn::BaseFloat32ArithmeticWorkload>; -template class armnn::BaseUint8ArithmeticWorkload>; - -template class armnn::BaseFloat32ArithmeticWorkload>; -template class armnn::BaseUint8ArithmeticWorkload>; diff --git a/src/backends/reference/workloads/RefArithmeticWorkload.hpp b/src/backends/reference/workloads/RefArithmeticWorkload.hpp deleted file mode 100644 index 75606177a6..0000000000 --- a/src/backends/reference/workloads/RefArithmeticWorkload.hpp +++ /dev/null @@ -1,122 +0,0 @@ -// -// Copyright © 2017 Arm Ltd. All rights reserved. -// SPDX-License-Identifier: MIT -// - -#pragma once - -#include -#include -#include -#include - -namespace armnn -{ - -template -class RefArithmeticWorkload -{ - // Needs specialization. The default is empty on purpose. -}; - -template -class BaseFloat32ArithmeticWorkload : public Float32Workload -{ -public: - using Float32Workload::Float32Workload; - void ExecuteImpl(const char * debugString) const; -}; - -template -class RefArithmeticWorkload - : public BaseFloat32ArithmeticWorkload -{ -public: - using BaseFloat32ArithmeticWorkload::BaseFloat32ArithmeticWorkload; - - virtual void Execute() const override - { - using Parent = BaseFloat32ArithmeticWorkload; - Parent::ExecuteImpl(StringMapping::Instance().Get(DebugString)); - } -}; - -template -class BaseUint8ArithmeticWorkload : public Uint8Workload -{ -public: - using Uint8Workload::Uint8Workload; - void ExecuteImpl(const char * debugString) const; -}; - -template -class RefArithmeticWorkload - : public BaseUint8ArithmeticWorkload -{ -public: - using BaseUint8ArithmeticWorkload::BaseUint8ArithmeticWorkload; - - virtual void Execute() const override - { - using Parent = BaseUint8ArithmeticWorkload; - Parent::ExecuteImpl(StringMapping::Instance().Get(DebugString)); - } -}; - -using RefAdditionFloat32Workload = - RefArithmeticWorkload, - DataType::Float32, - AdditionQueueDescriptor, - StringMapping::RefAdditionWorkload_Execute>; - -using RefAdditionUint8Workload = - RefArithmeticWorkload, - DataType::QuantisedAsymm8, - AdditionQueueDescriptor, - StringMapping::RefAdditionWorkload_Execute>; - - -using RefSubtractionFloat32Workload = - RefArithmeticWorkload, - DataType::Float32, - SubtractionQueueDescriptor, - StringMapping::RefSubtractionWorkload_Execute>; - -using RefSubtractionUint8Workload = - RefArithmeticWorkload, - DataType::QuantisedAsymm8, - SubtractionQueueDescriptor, - StringMapping::RefSubtractionWorkload_Execute>; - -using RefMultiplicationFloat32Workload = - RefArithmeticWorkload, - DataType::Float32, - MultiplicationQueueDescriptor, - StringMapping::RefMultiplicationWorkload_Execute>; - -using RefMultiplicationUint8Workload = - RefArithmeticWorkload, - DataType::QuantisedAsymm8, - MultiplicationQueueDescriptor, - StringMapping::RefMultiplicationWorkload_Execute>; - -using RefDivisionFloat32Workload = - RefArithmeticWorkload, - DataType::Float32, - DivisionQueueDescriptor, - StringMapping::RefDivisionWorkload_Execute>; - -using RefDivisionUint8Workload = - RefArithmeticWorkload, - DataType::QuantisedAsymm8, - DivisionQueueDescriptor, - StringMapping::RefDivisionWorkload_Execute>; - -} // armnn diff --git a/src/backends/reference/workloads/RefElementwiseWorkload.cpp b/src/backends/reference/workloads/RefElementwiseWorkload.cpp new file mode 100644 index 0000000000..8e312a7dd1 --- /dev/null +++ b/src/backends/reference/workloads/RefElementwiseWorkload.cpp @@ -0,0 +1,69 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#include "RefElementwiseWorkload.hpp" +#include "ElementwiseFunction.hpp" +#include "RefWorkloadUtils.hpp" +#include "Profiling.hpp" +#include + +namespace armnn +{ + +template +void BaseFloat32ElementwiseWorkload::ExecuteImpl(const char * debugString) const +{ + ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, debugString); + + auto data = Float32Workload::GetData(); + const TensorShape& inShape0 = GetTensorInfo(data.m_Inputs[0]).GetShape(); + const TensorShape& inShape1 = GetTensorInfo(data.m_Inputs[1]).GetShape(); + const TensorShape& outShape = GetTensorInfo(data.m_Outputs[0]).GetShape(); + + const float* inData0 = GetInputTensorDataFloat(0, data); + const float* inData1 = GetInputTensorDataFloat(1, data); + float* outData = GetOutputTensorDataFloat(0, data); + + ElementwiseFunction(inShape0, inShape1, outShape, inData0, inData1, outData); +} + +template +void BaseUint8ElementwiseWorkload::ExecuteImpl(const char * debugString) const +{ + ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, debugString); + + auto data = Uint8Workload::GetData(); + const TensorInfo& inputInfo0 = GetTensorInfo(data.m_Inputs[0]); + const TensorInfo& inputInfo1 = GetTensorInfo(data.m_Inputs[1]); + const TensorInfo& outputInfo = GetTensorInfo(data.m_Outputs[0]); + + auto dequant0 = Dequantize(GetInputTensorDataU8(0, data), inputInfo0); + auto dequant1 = Dequantize(GetInputTensorDataU8(1, data), inputInfo1); + + std::vector results(outputInfo.GetNumElements()); + + ElementwiseFunction(inputInfo0.GetShape(), + inputInfo1.GetShape(), + outputInfo.GetShape(), + dequant0.data(), + dequant1.data(), + results.data()); + + Quantize(GetOutputTensorDataU8(0, data), results.data(), outputInfo); +} + +} + +template class armnn::BaseFloat32ElementwiseWorkload>; +template class armnn::BaseUint8ElementwiseWorkload>; + +template class armnn::BaseFloat32ElementwiseWorkload>; +template class armnn::BaseUint8ElementwiseWorkload>; + +template class armnn::BaseFloat32ElementwiseWorkload>; +template class armnn::BaseUint8ElementwiseWorkload>; + +template class armnn::BaseFloat32ElementwiseWorkload>; +template class armnn::BaseUint8ElementwiseWorkload>; diff --git a/src/backends/reference/workloads/RefElementwiseWorkload.hpp b/src/backends/reference/workloads/RefElementwiseWorkload.hpp new file mode 100644 index 0000000000..156613a49f --- /dev/null +++ b/src/backends/reference/workloads/RefElementwiseWorkload.hpp @@ -0,0 +1,122 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#pragma once + +#include +#include +#include +#include + +namespace armnn +{ + +template +class RefElementwiseWorkload +{ + // Needs specialization. The default is empty on purpose. +}; + +template +class BaseFloat32ElementwiseWorkload : public Float32Workload +{ +public: + using Float32Workload::Float32Workload; + void ExecuteImpl(const char * debugString) const; +}; + +template +class RefElementwiseWorkload + : public BaseFloat32ElementwiseWorkload +{ +public: + using BaseFloat32ElementwiseWorkload::BaseFloat32ElementwiseWorkload; + + virtual void Execute() const override + { + using Parent = BaseFloat32ElementwiseWorkload; + Parent::ExecuteImpl(StringMapping::Instance().Get(DebugString)); + } +}; + +template +class BaseUint8ElementwiseWorkload : public Uint8Workload +{ +public: + using Uint8Workload::Uint8Workload; + void ExecuteImpl(const char * debugString) const; +}; + +template +class RefElementwiseWorkload + : public BaseUint8ElementwiseWorkload +{ +public: + using BaseUint8ElementwiseWorkload::BaseUint8ElementwiseWorkload; + + virtual void Execute() const override + { + using Parent = BaseUint8ElementwiseWorkload; + Parent::ExecuteImpl(StringMapping::Instance().Get(DebugString)); + } +}; + +using RefAdditionFloat32Workload = + RefElementwiseWorkload, + DataType::Float32, + AdditionQueueDescriptor, + StringMapping::RefAdditionWorkload_Execute>; + +using RefAdditionUint8Workload = + RefElementwiseWorkload, + DataType::QuantisedAsymm8, + AdditionQueueDescriptor, + StringMapping::RefAdditionWorkload_Execute>; + + +using RefSubtractionFloat32Workload = + RefElementwiseWorkload, + DataType::Float32, + SubtractionQueueDescriptor, + StringMapping::RefSubtractionWorkload_Execute>; + +using RefSubtractionUint8Workload = + RefElementwiseWorkload, + DataType::QuantisedAsymm8, + SubtractionQueueDescriptor, + StringMapping::RefSubtractionWorkload_Execute>; + +using RefMultiplicationFloat32Workload = + RefElementwiseWorkload, + DataType::Float32, + MultiplicationQueueDescriptor, + StringMapping::RefMultiplicationWorkload_Execute>; + +using RefMultiplicationUint8Workload = + RefElementwiseWorkload, + DataType::QuantisedAsymm8, + MultiplicationQueueDescriptor, + StringMapping::RefMultiplicationWorkload_Execute>; + +using RefDivisionFloat32Workload = + RefElementwiseWorkload, + DataType::Float32, + DivisionQueueDescriptor, + StringMapping::RefDivisionWorkload_Execute>; + +using RefDivisionUint8Workload = + RefElementwiseWorkload, + DataType::QuantisedAsymm8, + DivisionQueueDescriptor, + StringMapping::RefDivisionWorkload_Execute>; + +} // armnn diff --git a/src/backends/reference/workloads/RefWorkloads.hpp b/src/backends/reference/workloads/RefWorkloads.hpp index 20e9a9f5d3..86d86248b2 100644 --- a/src/backends/reference/workloads/RefWorkloads.hpp +++ b/src/backends/reference/workloads/RefWorkloads.hpp @@ -6,8 +6,8 @@ #pragma once #include "RefConstantUint8Workload.hpp" -#include "ArithmeticFunction.hpp" -#include "RefArithmeticWorkload.hpp" +#include "ElementwiseFunction.hpp" +#include "RefElementwiseWorkload.hpp" #include "ConvImpl.hpp" #include "RefBaseConstantWorkload.hpp" #include "RefConvolution2dUint8Workload.hpp" -- cgit v1.2.1