diff options
-rw-r--r-- | src/backends/ClLayerSupport.cpp | 9 | ||||
-rw-r--r-- | src/backends/ClLayerSupport.hpp | 5 | ||||
-rw-r--r-- | src/backends/ClWorkloadFactory.cpp | 3 | ||||
-rw-r--r-- | src/backends/ClWorkloads.hpp | 1 | ||||
-rw-r--r-- | src/backends/ClWorkloads/CMakeLists.txt | 2 | ||||
-rw-r--r-- | src/backends/ClWorkloads/ClPadWorkload.cpp | 62 | ||||
-rw-r--r-- | src/backends/ClWorkloads/ClPadWorkload.hpp | 32 | ||||
-rw-r--r-- | src/backends/ClWorkloads/backend.mk | 1 |
8 files changed, 114 insertions, 1 deletions
diff --git a/src/backends/ClLayerSupport.cpp b/src/backends/ClLayerSupport.cpp index 30a1330706..8c9ba6e3fe 100644 --- a/src/backends/ClLayerSupport.cpp +++ b/src/backends/ClLayerSupport.cpp @@ -25,6 +25,7 @@ #include "ClWorkloads/ClL2NormalizationFloatWorkload.hpp" #include "ClWorkloads/ClMultiplicationFloatWorkload.hpp" #include "ClWorkloads/ClFullyConnectedWorkload.hpp" +#include "ClWorkloads/ClPadWorkload.hpp" #include "ClWorkloads/ClPooling2dBaseWorkload.hpp" #include "ClWorkloads/ClPermuteWorkload.hpp" #include "ClWorkloads/ClNormalizationFloatWorkload.hpp" @@ -334,6 +335,14 @@ bool IsOutputSupportedCl(const TensorInfo& output, &TrueFunc<>); } +bool IsPadSupportedCl(const TensorInfo& input, + const TensorInfo& output, + const PadDescriptor& descriptor, + std::string* reasonIfUnsupported) +{ + return FORWARD_CL_LAYER_SUPPORT_FUNC(ClPadValidate(input, output, descriptor, reasonIfUnsupported)); +} + bool IsPermuteSupportedCl(const TensorInfo& input, const TensorInfo& output, const PermuteDescriptor& descriptor, diff --git a/src/backends/ClLayerSupport.hpp b/src/backends/ClLayerSupport.hpp index f5c1226e56..69c9b646f4 100644 --- a/src/backends/ClLayerSupport.hpp +++ b/src/backends/ClLayerSupport.hpp @@ -109,6 +109,11 @@ bool IsNormalizationSupportedCl(const TensorInfo& input, bool IsOutputSupportedCl(const TensorInfo& output, std::string* reasonIfUnsupported = nullptr); +bool IsPadSupportedCl(const TensorInfo& input, + const TensorInfo& output, + const PadDescriptor& descriptor, + std::string* reasonIfUnsupported = nullptr); + bool IsPermuteSupportedCl(const TensorInfo& input, const TensorInfo& output, const PermuteDescriptor& descriptor, diff --git a/src/backends/ClWorkloadFactory.cpp b/src/backends/ClWorkloadFactory.cpp index 6aee233ca4..6d7ff3d4e3 100644 --- a/src/backends/ClWorkloadFactory.cpp +++ b/src/backends/ClWorkloadFactory.cpp @@ -270,7 +270,8 @@ std::unique_ptr<IWorkload> ClWorkloadFactory::CreateMean(const MeanQueueDescript std::unique_ptr<IWorkload> ClWorkloadFactory::CreatePad(const PadQueueDescriptor& descriptor, const WorkloadInfo& info) const { - return MakeWorkload<NullWorkload, NullWorkload>(descriptor, info); + return MakeWorkload<ClPadWorkload<armnn::DataType::Float16, armnn::DataType::Float32>, + ClPadWorkload<armnn::DataType::QuantisedAsymm8>>(descriptor, info); } void ClWorkloadFactory::Finalize() diff --git a/src/backends/ClWorkloads.hpp b/src/backends/ClWorkloads.hpp index 2bbda8a62b..272f1b03ff 100644 --- a/src/backends/ClWorkloads.hpp +++ b/src/backends/ClWorkloads.hpp @@ -25,6 +25,7 @@ #include "backends/ClWorkloads/ClMergerUint8Workload.hpp" #include "backends/ClWorkloads/ClMultiplicationFloatWorkload.hpp" #include "backends/ClWorkloads/ClNormalizationFloatWorkload.hpp" +#include "backends/ClWorkloads/ClPadWorkload.hpp" #include "backends/ClWorkloads/ClPermuteWorkload.hpp" #include "backends/ClWorkloads/ClPooling2dFloatWorkload.hpp" #include "backends/ClWorkloads/ClPooling2dUint8Workload.hpp" diff --git a/src/backends/ClWorkloads/CMakeLists.txt b/src/backends/ClWorkloads/CMakeLists.txt index ac935b5cf7..ec61d534f0 100644 --- a/src/backends/ClWorkloads/CMakeLists.txt +++ b/src/backends/ClWorkloads/CMakeLists.txt @@ -54,6 +54,8 @@ list(APPEND armnnClBackend_sources ClMultiplicationFloatWorkload.hpp ClNormalizationFloatWorkload.cpp ClNormalizationFloatWorkload.hpp + ClPadWorkload.cpp + ClPadWorkload.hpp ClPermuteWorkload.cpp ClPermuteWorkload.hpp ClPooling2dBaseWorkload.cpp diff --git a/src/backends/ClWorkloads/ClPadWorkload.cpp b/src/backends/ClWorkloads/ClPadWorkload.cpp new file mode 100644 index 0000000000..45a9d0dc44 --- /dev/null +++ b/src/backends/ClWorkloads/ClPadWorkload.cpp @@ -0,0 +1,62 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#include "ClPadWorkload.hpp" + +#include "backends/ClTensorHandle.hpp" +#include "backends/aclCommon/ArmComputeTensorUtils.hpp" +#include "ClWorkloadUtils.hpp" +#include <arm_compute/core/Types.h> + +namespace armnn +{ +using namespace armcomputetensorutils; + +template <armnn::DataType... T> +ClPadWorkload<T...>::ClPadWorkload(const PadQueueDescriptor& descriptor, const WorkloadInfo& info) +: TypedWorkload<PadQueueDescriptor, T...>(descriptor, info) +{ + this->m_Data.ValidateInputsOutputs("ClPadWorkload", 1, 1); + + arm_compute::ICLTensor& input = static_cast<IClTensorHandle*>(this->m_Data.m_Inputs[0])->GetTensor(); + arm_compute::ICLTensor& output = static_cast<IClTensorHandle*>(this->m_Data.m_Outputs[0])->GetTensor(); + arm_compute::PaddingList padList = static_cast<arm_compute::PaddingList>(descriptor.m_Parameters.m_PadList); + + m_Layer.configure(&input, &output, padList); +} + +template <armnn::DataType... T> +void ClPadWorkload<T...>::Execute() const +{ + ARMNN_SCOPED_PROFILING_EVENT_CL("ClPadWorkload_Execute"); + m_Layer.run(); +} + +bool ClPadValidate(const TensorInfo& input, + const TensorInfo& output, + const PadDescriptor& descriptor, + std::string* reasonIfUnsupported) +{ + const arm_compute::TensorInfo aclInputInfo = BuildArmComputeTensorInfo(input); + const arm_compute::TensorInfo aclOutputInfo = BuildArmComputeTensorInfo(output); + arm_compute::PaddingList padList = static_cast<arm_compute::PaddingList>(descriptor.m_PadList); + + const arm_compute::Status aclStatus = arm_compute::CLPadLayer::validate(&aclInputInfo, + &aclOutputInfo, + padList); + + const bool supported = (aclStatus.error_code() == arm_compute::ErrorCode::OK); + if (!supported && reasonIfUnsupported) + { + *reasonIfUnsupported = aclStatus.error_description(); + } + + return supported; +} + +} // namespace armnn + +template class armnn::ClPadWorkload<armnn::DataType::Float16, armnn::DataType::Float32>; +template class armnn::ClPadWorkload<armnn::DataType::QuantisedAsymm8>; diff --git a/src/backends/ClWorkloads/ClPadWorkload.hpp b/src/backends/ClWorkloads/ClPadWorkload.hpp new file mode 100644 index 0000000000..0ec560d545 --- /dev/null +++ b/src/backends/ClWorkloads/ClPadWorkload.hpp @@ -0,0 +1,32 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#pragma once + +#include "backends/WorkloadData.hpp" +#include "backends/Workload.hpp" +#include <arm_compute/runtime/CL/functions/CLPadLayer.h> + +namespace armnn { + +template <armnn::DataType... dataTypes> +class ClPadWorkload : public TypedWorkload<PadQueueDescriptor, dataTypes...> +{ +public: + ClPadWorkload(const PadQueueDescriptor& descriptor, const WorkloadInfo& info); + + void Execute() const override; + +private: + mutable arm_compute::CLPadLayer m_Layer; +}; + +bool ClPadValidate(const TensorInfo& input, + const TensorInfo& output, + const PadDescriptor& descriptor, + std::string* reasonIfUnsupported); + +} //namespace armnn + diff --git a/src/backends/ClWorkloads/backend.mk b/src/backends/ClWorkloads/backend.mk index 9a20961287..9ac5004f64 100644 --- a/src/backends/ClWorkloads/backend.mk +++ b/src/backends/ClWorkloads/backend.mk @@ -33,6 +33,7 @@ BACKEND_SOURCES := \ ClMergerUint8Workload.cpp \ ClMultiplicationFloatWorkload.cpp \ ClNormalizationFloatWorkload.cpp \ + ClPadWorkload.cpp \ ClPermuteWorkload.cpp \ ClPooling2dBaseWorkload.cpp \ ClPooling2dFloatWorkload.cpp \ |