From 0a710c4c44be908a93a318e1fbd5c3535e849293 Mon Sep 17 00:00:00 2001 From: David Beck Date: Tue, 11 Sep 2018 15:21:14 +0100 Subject: IVGCVSW-1843 : refactor ClAdditionWorkload and ClSubtractionWorkload Change-Id: I0ca9f16217f8e32bb57a49b841611f10dabf021a --- src/armnn/backends/ClLayerSupport.cpp | 4 +- src/armnn/backends/ClWorkloadFactory.cpp | 6 +- src/armnn/backends/ClWorkloads.hpp | 6 +- .../ClWorkloads/ClAdditionBaseWorkload.cpp | 64 ---------------------- .../ClWorkloads/ClAdditionBaseWorkload.hpp | 29 ---------- .../ClWorkloads/ClAdditionFloatWorkload.cpp | 22 -------- .../ClWorkloads/ClAdditionFloatWorkload.hpp | 20 ------- .../ClWorkloads/ClAdditionUint8Workload.cpp | 18 ------ .../ClWorkloads/ClAdditionUint8Workload.hpp | 20 ------- .../backends/ClWorkloads/ClAdditionWorkload.cpp | 64 ++++++++++++++++++++++ .../backends/ClWorkloads/ClAdditionWorkload.hpp | 29 ++++++++++ .../ClWorkloads/ClSubtractionBaseWorkload.cpp | 64 ---------------------- .../ClWorkloads/ClSubtractionBaseWorkload.hpp | 29 ---------- .../ClWorkloads/ClSubtractionFloatWorkload.cpp | 22 -------- .../ClWorkloads/ClSubtractionFloatWorkload.hpp | 20 ------- .../ClWorkloads/ClSubtractionUint8Workload.cpp | 18 ------ .../ClWorkloads/ClSubtractionUint8Workload.hpp | 20 ------- .../backends/ClWorkloads/ClSubtractionWorkload.cpp | 64 ++++++++++++++++++++++ .../backends/ClWorkloads/ClSubtractionWorkload.hpp | 29 ++++++++++ src/armnn/backends/test/CreateWorkloadCl.cpp | 8 +-- 20 files changed, 198 insertions(+), 358 deletions(-) delete mode 100644 src/armnn/backends/ClWorkloads/ClAdditionBaseWorkload.cpp delete mode 100644 src/armnn/backends/ClWorkloads/ClAdditionBaseWorkload.hpp delete mode 100644 src/armnn/backends/ClWorkloads/ClAdditionFloatWorkload.cpp delete mode 100644 src/armnn/backends/ClWorkloads/ClAdditionFloatWorkload.hpp delete mode 100644 src/armnn/backends/ClWorkloads/ClAdditionUint8Workload.cpp delete mode 100644 src/armnn/backends/ClWorkloads/ClAdditionUint8Workload.hpp create mode 100644 src/armnn/backends/ClWorkloads/ClAdditionWorkload.cpp create mode 100644 src/armnn/backends/ClWorkloads/ClAdditionWorkload.hpp delete mode 100644 src/armnn/backends/ClWorkloads/ClSubtractionBaseWorkload.cpp delete mode 100644 src/armnn/backends/ClWorkloads/ClSubtractionBaseWorkload.hpp delete mode 100644 src/armnn/backends/ClWorkloads/ClSubtractionFloatWorkload.cpp delete mode 100644 src/armnn/backends/ClWorkloads/ClSubtractionFloatWorkload.hpp delete mode 100644 src/armnn/backends/ClWorkloads/ClSubtractionUint8Workload.cpp delete mode 100644 src/armnn/backends/ClWorkloads/ClSubtractionUint8Workload.hpp create mode 100644 src/armnn/backends/ClWorkloads/ClSubtractionWorkload.cpp create mode 100644 src/armnn/backends/ClWorkloads/ClSubtractionWorkload.hpp (limited to 'src/armnn') diff --git a/src/armnn/backends/ClLayerSupport.cpp b/src/armnn/backends/ClLayerSupport.cpp index 3dba1ec94c..aeb2759aa1 100644 --- a/src/armnn/backends/ClLayerSupport.cpp +++ b/src/armnn/backends/ClLayerSupport.cpp @@ -14,7 +14,7 @@ #include #ifdef ARMCOMPUTECL_ENABLED -#include "ClWorkloads/ClAdditionFloatWorkload.hpp" +#include "ClWorkloads/ClAdditionWorkload.hpp" #include "ClWorkloads/ClActivationFloatWorkload.hpp" #include "ClWorkloads/ClBatchNormalizationFloatWorkload.hpp" #include "ClWorkloads/ClConvertFp16ToFp32Workload.hpp" @@ -29,7 +29,7 @@ #include "ClWorkloads/ClPermuteWorkload.hpp" #include "ClWorkloads/ClNormalizationFloatWorkload.hpp" #include "ClWorkloads/ClSoftmaxBaseWorkload.hpp" -#include "ClWorkloads/ClSubtractionFloatWorkload.hpp" +#include "ClWorkloads/ClSubtractionWorkload.hpp" #include "ClWorkloads/ClLstmFloatWorkload.hpp" #endif diff --git a/src/armnn/backends/ClWorkloadFactory.cpp b/src/armnn/backends/ClWorkloadFactory.cpp index 056a201783..217c637784 100644 --- a/src/armnn/backends/ClWorkloadFactory.cpp +++ b/src/armnn/backends/ClWorkloadFactory.cpp @@ -154,7 +154,8 @@ std::unique_ptr ClWorkloadFactory::CreateNormalization(const N std::unique_ptr ClWorkloadFactory::CreateAddition(const AdditionQueueDescriptor& descriptor, const WorkloadInfo& info) const { - return MakeWorkload(descriptor, info); + return MakeWorkload, + ClAdditionWorkload>(descriptor, info); } std::unique_ptr ClWorkloadFactory::CreateMultiplication( @@ -172,7 +173,8 @@ std::unique_ptr ClWorkloadFactory::CreateDivision( std::unique_ptr ClWorkloadFactory::CreateSubtraction(const SubtractionQueueDescriptor& descriptor, const WorkloadInfo& info) const { - return MakeWorkload(descriptor, info); + return MakeWorkload, + ClSubtractionWorkload>(descriptor, info); } std::unique_ptr ClWorkloadFactory::CreateBatchNormalization( diff --git a/src/armnn/backends/ClWorkloads.hpp b/src/armnn/backends/ClWorkloads.hpp index 0800401a22..3472bca45c 100644 --- a/src/armnn/backends/ClWorkloads.hpp +++ b/src/armnn/backends/ClWorkloads.hpp @@ -6,8 +6,7 @@ #pragma once #include "backends/ClWorkloads/ClActivationFloatWorkload.hpp" #include "backends/ClWorkloads/ClActivationUint8Workload.hpp" -#include "backends/ClWorkloads/ClAdditionFloatWorkload.hpp" -#include "backends/ClWorkloads/ClAdditionUint8Workload.hpp" +#include "backends/ClWorkloads/ClAdditionWorkload.hpp" #include "backends/ClWorkloads/ClBaseConstantWorkload.hpp" #include "backends/ClWorkloads/ClBaseMergerWorkload.hpp" #include "backends/ClWorkloads/ClBatchNormalizationFloatWorkload.hpp" @@ -36,7 +35,6 @@ #include "backends/ClWorkloads/ClSoftmaxUint8Workload.hpp" #include "backends/ClWorkloads/ClSplitterFloatWorkload.hpp" #include "backends/ClWorkloads/ClSplitterUint8Workload.hpp" -#include "backends/ClWorkloads/ClSubtractionFloatWorkload.hpp" -#include "backends/ClWorkloads/ClSubtractionUint8Workload.hpp" +#include "backends/ClWorkloads/ClSubtractionWorkload.hpp" #include "backends/ClWorkloads/ClConvertFp16ToFp32Workload.hpp" #include "backends/ClWorkloads/ClConvertFp32ToFp16Workload.hpp" diff --git a/src/armnn/backends/ClWorkloads/ClAdditionBaseWorkload.cpp b/src/armnn/backends/ClWorkloads/ClAdditionBaseWorkload.cpp deleted file mode 100644 index eb14aa3891..0000000000 --- a/src/armnn/backends/ClWorkloads/ClAdditionBaseWorkload.cpp +++ /dev/null @@ -1,64 +0,0 @@ -// -// Copyright © 2017 Arm Ltd. All rights reserved. -// SPDX-License-Identifier: MIT -// - -#include "ClAdditionBaseWorkload.hpp" - -#include "backends/ClTensorHandle.hpp" -#include "backends/CpuTensorHandle.hpp" -#include "backends/ArmComputeTensorUtils.hpp" - -namespace armnn -{ -using namespace armcomputetensorutils; - -static constexpr arm_compute::ConvertPolicy g_AclConvertPolicy = arm_compute::ConvertPolicy::SATURATE; - -template -ClAdditionBaseWorkload::ClAdditionBaseWorkload(const AdditionQueueDescriptor& descriptor, - const WorkloadInfo& info) - : TypedWorkload(descriptor, info) -{ - this->m_Data.ValidateInputsOutputs("ClAdditionBaseWorkload", 2, 1); - - arm_compute::ICLTensor& input0 = static_cast(this->m_Data.m_Inputs[0])->GetTensor(); - arm_compute::ICLTensor& input1 = static_cast(this->m_Data.m_Inputs[1])->GetTensor(); - arm_compute::ICLTensor& output = static_cast(this->m_Data.m_Outputs[0])->GetTensor(); - m_Layer.configure(&input0, &input1, &output, g_AclConvertPolicy); -} - -template -void ClAdditionBaseWorkload::Execute() const -{ - ARMNN_SCOPED_PROFILING_EVENT_CL("ClAdditionBaseWorkload_Execute"); - m_Layer.run(); -} - -bool ClAdditionValidate(const TensorInfo& input0, - const TensorInfo& input1, - const TensorInfo& output, - std::string* reasonIfUnsupported) -{ - const arm_compute::TensorInfo aclInput0Info = BuildArmComputeTensorInfo(input0); - const arm_compute::TensorInfo aclInput1Info = BuildArmComputeTensorInfo(input1); - const arm_compute::TensorInfo aclOutputInfo = BuildArmComputeTensorInfo(output); - - const arm_compute::Status aclStatus = arm_compute::CLArithmeticAddition::validate(&aclInput0Info, - &aclInput1Info, - &aclOutputInfo, - g_AclConvertPolicy); - - const bool supported = (aclStatus.error_code() == arm_compute::ErrorCode::OK); - if (!supported && reasonIfUnsupported) - { - *reasonIfUnsupported = aclStatus.error_description(); - } - - return supported; -} - -} //namespace armnn - -template class armnn::ClAdditionBaseWorkload; -template class armnn::ClAdditionBaseWorkload; diff --git a/src/armnn/backends/ClWorkloads/ClAdditionBaseWorkload.hpp b/src/armnn/backends/ClWorkloads/ClAdditionBaseWorkload.hpp deleted file mode 100644 index b3bf1fe597..0000000000 --- a/src/armnn/backends/ClWorkloads/ClAdditionBaseWorkload.hpp +++ /dev/null @@ -1,29 +0,0 @@ -// -// Copyright © 2017 Arm Ltd. All rights reserved. -// SPDX-License-Identifier: MIT -// - -#pragma once - -#include "backends/ClWorkloadUtils.hpp" - -namespace armnn -{ - -template -class ClAdditionBaseWorkload : public TypedWorkload -{ -public: - ClAdditionBaseWorkload(const AdditionQueueDescriptor& descriptor, const WorkloadInfo& info); - - void Execute() const override; - -private: - mutable arm_compute::CLArithmeticAddition m_Layer; -}; - -bool ClAdditionValidate(const TensorInfo& input0, - const TensorInfo& input1, - const TensorInfo& output, - std::string* reasonIfUnsupported); -} //namespace armnn diff --git a/src/armnn/backends/ClWorkloads/ClAdditionFloatWorkload.cpp b/src/armnn/backends/ClWorkloads/ClAdditionFloatWorkload.cpp deleted file mode 100644 index b51d8a7efd..0000000000 --- a/src/armnn/backends/ClWorkloads/ClAdditionFloatWorkload.cpp +++ /dev/null @@ -1,22 +0,0 @@ -// -// Copyright © 2017 Arm Ltd. All rights reserved. -// SPDX-License-Identifier: MIT -// - -#include "ClAdditionFloatWorkload.hpp" - -#include "backends/ClTensorHandle.hpp" -#include "backends/CpuTensorHandle.hpp" -#include "backends/ArmComputeTensorUtils.hpp" - -namespace armnn -{ -using namespace armcomputetensorutils; - -void ClAdditionFloatWorkload::Execute() const -{ - ARMNN_SCOPED_PROFILING_EVENT_CL("ClAdditionFloatWorkload_Execute"); - ClAdditionBaseWorkload::Execute(); -} - -} //namespace armnn diff --git a/src/armnn/backends/ClWorkloads/ClAdditionFloatWorkload.hpp b/src/armnn/backends/ClWorkloads/ClAdditionFloatWorkload.hpp deleted file mode 100644 index de33ca6ce4..0000000000 --- a/src/armnn/backends/ClWorkloads/ClAdditionFloatWorkload.hpp +++ /dev/null @@ -1,20 +0,0 @@ -// -// Copyright © 2017 Arm Ltd. All rights reserved. -// SPDX-License-Identifier: MIT -// - -#pragma once - -#include "ClAdditionBaseWorkload.hpp" - -namespace armnn -{ - -class ClAdditionFloatWorkload : public ClAdditionBaseWorkload -{ -public: - using ClAdditionBaseWorkload::ClAdditionBaseWorkload; - void Execute() const override; -}; - -} //namespace armnn diff --git a/src/armnn/backends/ClWorkloads/ClAdditionUint8Workload.cpp b/src/armnn/backends/ClWorkloads/ClAdditionUint8Workload.cpp deleted file mode 100644 index 57b9062c15..0000000000 --- a/src/armnn/backends/ClWorkloads/ClAdditionUint8Workload.cpp +++ /dev/null @@ -1,18 +0,0 @@ -// -// Copyright © 2017 Arm Ltd. All rights reserved. -// SPDX-License-Identifier: MIT -// - -#include "ClAdditionUint8Workload.hpp" - -namespace armnn -{ -using namespace armcomputetensorutils; - -void ClAdditionUint8Workload::Execute() const -{ - ARMNN_SCOPED_PROFILING_EVENT_CL("ClAdditionUint8Workload_Execute"); - ClAdditionBaseWorkload::Execute(); -} - -} //namespace armnn diff --git a/src/armnn/backends/ClWorkloads/ClAdditionUint8Workload.hpp b/src/armnn/backends/ClWorkloads/ClAdditionUint8Workload.hpp deleted file mode 100644 index d127e7e5c3..0000000000 --- a/src/armnn/backends/ClWorkloads/ClAdditionUint8Workload.hpp +++ /dev/null @@ -1,20 +0,0 @@ -// -// Copyright © 2017 Arm Ltd. All rights reserved. -// SPDX-License-Identifier: MIT -// - -#pragma once - -#include "ClAdditionBaseWorkload.hpp" - -namespace armnn -{ - -class ClAdditionUint8Workload : public ClAdditionBaseWorkload -{ -public: - using ClAdditionBaseWorkload::ClAdditionBaseWorkload; - void Execute() const override; -}; - -} //namespace armnn diff --git a/src/armnn/backends/ClWorkloads/ClAdditionWorkload.cpp b/src/armnn/backends/ClWorkloads/ClAdditionWorkload.cpp new file mode 100644 index 0000000000..0bba327bef --- /dev/null +++ b/src/armnn/backends/ClWorkloads/ClAdditionWorkload.cpp @@ -0,0 +1,64 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#include "ClAdditionWorkload.hpp" + +#include "backends/ClTensorHandle.hpp" +#include "backends/CpuTensorHandle.hpp" +#include "backends/ArmComputeTensorUtils.hpp" + +namespace armnn +{ +using namespace armcomputetensorutils; + +static constexpr arm_compute::ConvertPolicy g_AclConvertPolicy = arm_compute::ConvertPolicy::SATURATE; + +template +ClAdditionWorkload::ClAdditionWorkload(const AdditionQueueDescriptor& descriptor, + const WorkloadInfo& info) + : TypedWorkload(descriptor, info) +{ + this->m_Data.ValidateInputsOutputs("ClAdditionWorkload", 2, 1); + + arm_compute::ICLTensor& input0 = static_cast(this->m_Data.m_Inputs[0])->GetTensor(); + arm_compute::ICLTensor& input1 = static_cast(this->m_Data.m_Inputs[1])->GetTensor(); + arm_compute::ICLTensor& output = static_cast(this->m_Data.m_Outputs[0])->GetTensor(); + m_Layer.configure(&input0, &input1, &output, g_AclConvertPolicy); +} + +template +void ClAdditionWorkload::Execute() const +{ + ARMNN_SCOPED_PROFILING_EVENT_CL("ClAdditionWorkload_Execute"); + m_Layer.run(); +} + +bool ClAdditionValidate(const TensorInfo& input0, + const TensorInfo& input1, + const TensorInfo& output, + std::string* reasonIfUnsupported) +{ + const arm_compute::TensorInfo aclInput0Info = BuildArmComputeTensorInfo(input0); + const arm_compute::TensorInfo aclInput1Info = BuildArmComputeTensorInfo(input1); + const arm_compute::TensorInfo aclOutputInfo = BuildArmComputeTensorInfo(output); + + const arm_compute::Status aclStatus = arm_compute::CLArithmeticAddition::validate(&aclInput0Info, + &aclInput1Info, + &aclOutputInfo, + g_AclConvertPolicy); + + const bool supported = (aclStatus.error_code() == arm_compute::ErrorCode::OK); + if (!supported && reasonIfUnsupported) + { + *reasonIfUnsupported = aclStatus.error_description(); + } + + return supported; +} + +} //namespace armnn + +template class armnn::ClAdditionWorkload; +template class armnn::ClAdditionWorkload; diff --git a/src/armnn/backends/ClWorkloads/ClAdditionWorkload.hpp b/src/armnn/backends/ClWorkloads/ClAdditionWorkload.hpp new file mode 100644 index 0000000000..8af8f23788 --- /dev/null +++ b/src/armnn/backends/ClWorkloads/ClAdditionWorkload.hpp @@ -0,0 +1,29 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#pragma once + +#include "backends/ClWorkloadUtils.hpp" + +namespace armnn +{ + +template +class ClAdditionWorkload : public TypedWorkload +{ +public: + ClAdditionWorkload(const AdditionQueueDescriptor& descriptor, const WorkloadInfo& info); + + void Execute() const override; + +private: + mutable arm_compute::CLArithmeticAddition m_Layer; +}; + +bool ClAdditionValidate(const TensorInfo& input0, + const TensorInfo& input1, + const TensorInfo& output, + std::string* reasonIfUnsupported); +} //namespace armnn diff --git a/src/armnn/backends/ClWorkloads/ClSubtractionBaseWorkload.cpp b/src/armnn/backends/ClWorkloads/ClSubtractionBaseWorkload.cpp deleted file mode 100644 index 2145ed4a2a..0000000000 --- a/src/armnn/backends/ClWorkloads/ClSubtractionBaseWorkload.cpp +++ /dev/null @@ -1,64 +0,0 @@ -// -// Copyright © 2017 Arm Ltd. All rights reserved. -// SPDX-License-Identifier: MIT -// - -#include "ClSubtractionBaseWorkload.hpp" - -#include "backends/ClTensorHandle.hpp" -#include "backends/CpuTensorHandle.hpp" -#include "backends/ArmComputeTensorUtils.hpp" - -namespace armnn -{ -using namespace armcomputetensorutils; - -static constexpr arm_compute::ConvertPolicy g_AclConvertPolicy = arm_compute::ConvertPolicy::SATURATE; - -template -ClSubtractionBaseWorkload::ClSubtractionBaseWorkload(const SubtractionQueueDescriptor& descriptor, - const WorkloadInfo& info) - : TypedWorkload(descriptor, info) -{ - this->m_Data.ValidateInputsOutputs("ClSubtractionBaseWorkload", 2, 1); - - arm_compute::ICLTensor& input0 = static_cast(this->m_Data.m_Inputs[0])->GetTensor(); - arm_compute::ICLTensor& input1 = static_cast(this->m_Data.m_Inputs[1])->GetTensor(); - arm_compute::ICLTensor& output = static_cast(this->m_Data.m_Outputs[0])->GetTensor(); - m_Layer.configure(&input0, &input1, &output, g_AclConvertPolicy); -} - -template -void ClSubtractionBaseWorkload::Execute() const -{ - ARMNN_SCOPED_PROFILING_EVENT_CL("ClSubtractionBaseWorkload_Execute"); - m_Layer.run(); -} - -bool ClSubtractionValidate(const TensorInfo& input0, - const TensorInfo& input1, - const TensorInfo& output, - std::string* reasonIfUnsupported) -{ - const arm_compute::TensorInfo aclInput0Info = BuildArmComputeTensorInfo(input0); - const arm_compute::TensorInfo aclInput1Info = BuildArmComputeTensorInfo(input1); - const arm_compute::TensorInfo aclOutputInfo = BuildArmComputeTensorInfo(output); - - const arm_compute::Status aclStatus = arm_compute::CLArithmeticSubtraction::validate(&aclInput0Info, - &aclInput1Info, - &aclOutputInfo, - g_AclConvertPolicy); - - const bool supported = (aclStatus.error_code() == arm_compute::ErrorCode::OK); - if (!supported && reasonIfUnsupported) - { - *reasonIfUnsupported = aclStatus.error_description(); - } - - return supported; -} - -} //namespace armnn - -template class armnn::ClSubtractionBaseWorkload; -template class armnn::ClSubtractionBaseWorkload; diff --git a/src/armnn/backends/ClWorkloads/ClSubtractionBaseWorkload.hpp b/src/armnn/backends/ClWorkloads/ClSubtractionBaseWorkload.hpp deleted file mode 100644 index e4595d405a..0000000000 --- a/src/armnn/backends/ClWorkloads/ClSubtractionBaseWorkload.hpp +++ /dev/null @@ -1,29 +0,0 @@ -// -// Copyright © 2017 Arm Ltd. All rights reserved. -// SPDX-License-Identifier: MIT -// - -#pragma once - -#include "backends/ClWorkloadUtils.hpp" - -namespace armnn -{ - -template -class ClSubtractionBaseWorkload : public TypedWorkload -{ -public: - ClSubtractionBaseWorkload(const SubtractionQueueDescriptor& descriptor, const WorkloadInfo& info); - - void Execute() const override; - -private: - mutable arm_compute::CLArithmeticSubtraction m_Layer; -}; - -bool ClSubtractionValidate(const TensorInfo& input0, - const TensorInfo& input1, - const TensorInfo& output, - std::string* reasonIfUnsupported); -} //namespace armnn diff --git a/src/armnn/backends/ClWorkloads/ClSubtractionFloatWorkload.cpp b/src/armnn/backends/ClWorkloads/ClSubtractionFloatWorkload.cpp deleted file mode 100644 index 3321e20100..0000000000 --- a/src/armnn/backends/ClWorkloads/ClSubtractionFloatWorkload.cpp +++ /dev/null @@ -1,22 +0,0 @@ -// -// Copyright © 2017 Arm Ltd. All rights reserved. -// SPDX-License-Identifier: MIT -// - -#include "ClSubtractionFloatWorkload.hpp" - -#include "backends/ClTensorHandle.hpp" -#include "backends/CpuTensorHandle.hpp" -#include "backends/ArmComputeTensorUtils.hpp" - -namespace armnn -{ -using namespace armcomputetensorutils; - -void ClSubtractionFloatWorkload::Execute() const -{ - ARMNN_SCOPED_PROFILING_EVENT_CL("ClSubtractionFloatWorkload_Execute"); - ClSubtractionBaseWorkload::Execute(); -} - -} //namespace armnn diff --git a/src/armnn/backends/ClWorkloads/ClSubtractionFloatWorkload.hpp b/src/armnn/backends/ClWorkloads/ClSubtractionFloatWorkload.hpp deleted file mode 100644 index 34a5e40983..0000000000 --- a/src/armnn/backends/ClWorkloads/ClSubtractionFloatWorkload.hpp +++ /dev/null @@ -1,20 +0,0 @@ -// -// Copyright © 2017 Arm Ltd. All rights reserved. -// SPDX-License-Identifier: MIT -// - -#pragma once - -#include "ClSubtractionBaseWorkload.hpp" - -namespace armnn -{ - -class ClSubtractionFloatWorkload : public ClSubtractionBaseWorkload -{ -public: - using ClSubtractionBaseWorkload::ClSubtractionBaseWorkload; - void Execute() const override; -}; - -} //namespace armnn diff --git a/src/armnn/backends/ClWorkloads/ClSubtractionUint8Workload.cpp b/src/armnn/backends/ClWorkloads/ClSubtractionUint8Workload.cpp deleted file mode 100644 index 966068d648..0000000000 --- a/src/armnn/backends/ClWorkloads/ClSubtractionUint8Workload.cpp +++ /dev/null @@ -1,18 +0,0 @@ -// -// Copyright © 2017 Arm Ltd. All rights reserved. -// SPDX-License-Identifier: MIT -// - -#include "ClSubtractionUint8Workload.hpp" - -namespace armnn -{ -using namespace armcomputetensorutils; - -void ClSubtractionUint8Workload::Execute() const -{ - ARMNN_SCOPED_PROFILING_EVENT_CL("ClSubtractionUint8Workload_Execute"); - ClSubtractionBaseWorkload::Execute(); -} - -} //namespace armnn diff --git a/src/armnn/backends/ClWorkloads/ClSubtractionUint8Workload.hpp b/src/armnn/backends/ClWorkloads/ClSubtractionUint8Workload.hpp deleted file mode 100644 index 15b2059615..0000000000 --- a/src/armnn/backends/ClWorkloads/ClSubtractionUint8Workload.hpp +++ /dev/null @@ -1,20 +0,0 @@ -// -// Copyright © 2017 Arm Ltd. All rights reserved. -// SPDX-License-Identifier: MIT -// - -#pragma once - -#include "ClSubtractionBaseWorkload.hpp" - -namespace armnn -{ - -class ClSubtractionUint8Workload : public ClSubtractionBaseWorkload -{ -public: - using ClSubtractionBaseWorkload::ClSubtractionBaseWorkload; - void Execute() const override; -}; - -} //namespace armnn diff --git a/src/armnn/backends/ClWorkloads/ClSubtractionWorkload.cpp b/src/armnn/backends/ClWorkloads/ClSubtractionWorkload.cpp new file mode 100644 index 0000000000..ec8bfc6351 --- /dev/null +++ b/src/armnn/backends/ClWorkloads/ClSubtractionWorkload.cpp @@ -0,0 +1,64 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#include "ClSubtractionWorkload.hpp" + +#include "backends/ClTensorHandle.hpp" +#include "backends/CpuTensorHandle.hpp" +#include "backends/ArmComputeTensorUtils.hpp" + +namespace armnn +{ +using namespace armcomputetensorutils; + +static constexpr arm_compute::ConvertPolicy g_AclConvertPolicy = arm_compute::ConvertPolicy::SATURATE; + +template +ClSubtractionWorkload::ClSubtractionWorkload(const SubtractionQueueDescriptor& descriptor, + const WorkloadInfo& info) + : TypedWorkload(descriptor, info) +{ + this->m_Data.ValidateInputsOutputs("ClSubtractionWorkload", 2, 1); + + arm_compute::ICLTensor& input0 = static_cast(this->m_Data.m_Inputs[0])->GetTensor(); + arm_compute::ICLTensor& input1 = static_cast(this->m_Data.m_Inputs[1])->GetTensor(); + arm_compute::ICLTensor& output = static_cast(this->m_Data.m_Outputs[0])->GetTensor(); + m_Layer.configure(&input0, &input1, &output, g_AclConvertPolicy); +} + +template +void ClSubtractionWorkload::Execute() const +{ + ARMNN_SCOPED_PROFILING_EVENT_CL("ClSubtractionWorkload_Execute"); + m_Layer.run(); +} + +bool ClSubtractionValidate(const TensorInfo& input0, + const TensorInfo& input1, + const TensorInfo& output, + std::string* reasonIfUnsupported) +{ + const arm_compute::TensorInfo aclInput0Info = BuildArmComputeTensorInfo(input0); + const arm_compute::TensorInfo aclInput1Info = BuildArmComputeTensorInfo(input1); + const arm_compute::TensorInfo aclOutputInfo = BuildArmComputeTensorInfo(output); + + const arm_compute::Status aclStatus = arm_compute::CLArithmeticSubtraction::validate(&aclInput0Info, + &aclInput1Info, + &aclOutputInfo, + g_AclConvertPolicy); + + const bool supported = (aclStatus.error_code() == arm_compute::ErrorCode::OK); + if (!supported && reasonIfUnsupported) + { + *reasonIfUnsupported = aclStatus.error_description(); + } + + return supported; +} + +} //namespace armnn + +template class armnn::ClSubtractionWorkload; +template class armnn::ClSubtractionWorkload; diff --git a/src/armnn/backends/ClWorkloads/ClSubtractionWorkload.hpp b/src/armnn/backends/ClWorkloads/ClSubtractionWorkload.hpp new file mode 100644 index 0000000000..422e6a7379 --- /dev/null +++ b/src/armnn/backends/ClWorkloads/ClSubtractionWorkload.hpp @@ -0,0 +1,29 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#pragma once + +#include "backends/ClWorkloadUtils.hpp" + +namespace armnn +{ + +template +class ClSubtractionWorkload : public TypedWorkload +{ +public: + ClSubtractionWorkload(const SubtractionQueueDescriptor& descriptor, const WorkloadInfo& info); + + void Execute() const override; + +private: + mutable arm_compute::CLArithmeticSubtraction m_Layer; +}; + +bool ClSubtractionValidate(const TensorInfo& input0, + const TensorInfo& input1, + const TensorInfo& output, + std::string* reasonIfUnsupported); +} //namespace armnn diff --git a/src/armnn/backends/test/CreateWorkloadCl.cpp b/src/armnn/backends/test/CreateWorkloadCl.cpp index 340279e619..23843bd095 100644 --- a/src/armnn/backends/test/CreateWorkloadCl.cpp +++ b/src/armnn/backends/test/CreateWorkloadCl.cpp @@ -69,7 +69,7 @@ static void ClCreateArithmethicWorkloadTest() BOOST_AUTO_TEST_CASE(CreateAdditionFloatWorkload) { - ClCreateArithmethicWorkloadTest, AdditionQueueDescriptor, AdditionLayer, armnn::DataType::Float32>(); @@ -77,7 +77,7 @@ BOOST_AUTO_TEST_CASE(CreateAdditionFloatWorkload) BOOST_AUTO_TEST_CASE(CreateAdditionFloat16Workload) { - ClCreateArithmethicWorkloadTest, AdditionQueueDescriptor, AdditionLayer, armnn::DataType::Float16>(); @@ -85,7 +85,7 @@ BOOST_AUTO_TEST_CASE(CreateAdditionFloat16Workload) BOOST_AUTO_TEST_CASE(CreateSubtractionFloatWorkload) { - ClCreateArithmethicWorkloadTest, SubtractionQueueDescriptor, SubtractionLayer, armnn::DataType::Float32>(); @@ -93,7 +93,7 @@ BOOST_AUTO_TEST_CASE(CreateSubtractionFloatWorkload) BOOST_AUTO_TEST_CASE(CreateSubtractionFloat16Workload) { - ClCreateArithmethicWorkloadTest, SubtractionQueueDescriptor, SubtractionLayer, armnn::DataType::Float16>(); -- cgit v1.2.1