diff options
Diffstat (limited to 'src/backends/ClWorkloads')
-rw-r--r-- | src/backends/ClWorkloads/CMakeLists.txt | 2 | ||||
-rw-r--r-- | src/backends/ClWorkloads/ClPadWorkload.cpp | 62 | ||||
-rw-r--r-- | src/backends/ClWorkloads/ClPadWorkload.hpp | 32 | ||||
-rw-r--r-- | src/backends/ClWorkloads/backend.mk | 1 |
4 files changed, 97 insertions, 0 deletions
diff --git a/src/backends/ClWorkloads/CMakeLists.txt b/src/backends/ClWorkloads/CMakeLists.txt index ac935b5cf7..ec61d534f0 100644 --- a/src/backends/ClWorkloads/CMakeLists.txt +++ b/src/backends/ClWorkloads/CMakeLists.txt @@ -54,6 +54,8 @@ list(APPEND armnnClBackend_sources ClMultiplicationFloatWorkload.hpp ClNormalizationFloatWorkload.cpp ClNormalizationFloatWorkload.hpp + ClPadWorkload.cpp + ClPadWorkload.hpp ClPermuteWorkload.cpp ClPermuteWorkload.hpp ClPooling2dBaseWorkload.cpp diff --git a/src/backends/ClWorkloads/ClPadWorkload.cpp b/src/backends/ClWorkloads/ClPadWorkload.cpp new file mode 100644 index 0000000000..45a9d0dc44 --- /dev/null +++ b/src/backends/ClWorkloads/ClPadWorkload.cpp @@ -0,0 +1,62 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#include "ClPadWorkload.hpp" + +#include "backends/ClTensorHandle.hpp" +#include "backends/aclCommon/ArmComputeTensorUtils.hpp" +#include "ClWorkloadUtils.hpp" +#include <arm_compute/core/Types.h> + +namespace armnn +{ +using namespace armcomputetensorutils; + +template <armnn::DataType... T> +ClPadWorkload<T...>::ClPadWorkload(const PadQueueDescriptor& descriptor, const WorkloadInfo& info) +: TypedWorkload<PadQueueDescriptor, T...>(descriptor, info) +{ + this->m_Data.ValidateInputsOutputs("ClPadWorkload", 1, 1); + + arm_compute::ICLTensor& input = static_cast<IClTensorHandle*>(this->m_Data.m_Inputs[0])->GetTensor(); + arm_compute::ICLTensor& output = static_cast<IClTensorHandle*>(this->m_Data.m_Outputs[0])->GetTensor(); + arm_compute::PaddingList padList = static_cast<arm_compute::PaddingList>(descriptor.m_Parameters.m_PadList); + + m_Layer.configure(&input, &output, padList); +} + +template <armnn::DataType... T> +void ClPadWorkload<T...>::Execute() const +{ + ARMNN_SCOPED_PROFILING_EVENT_CL("ClPadWorkload_Execute"); + m_Layer.run(); +} + +bool ClPadValidate(const TensorInfo& input, + const TensorInfo& output, + const PadDescriptor& descriptor, + std::string* reasonIfUnsupported) +{ + const arm_compute::TensorInfo aclInputInfo = BuildArmComputeTensorInfo(input); + const arm_compute::TensorInfo aclOutputInfo = BuildArmComputeTensorInfo(output); + arm_compute::PaddingList padList = static_cast<arm_compute::PaddingList>(descriptor.m_PadList); + + const arm_compute::Status aclStatus = arm_compute::CLPadLayer::validate(&aclInputInfo, + &aclOutputInfo, + padList); + + const bool supported = (aclStatus.error_code() == arm_compute::ErrorCode::OK); + if (!supported && reasonIfUnsupported) + { + *reasonIfUnsupported = aclStatus.error_description(); + } + + return supported; +} + +} // namespace armnn + +template class armnn::ClPadWorkload<armnn::DataType::Float16, armnn::DataType::Float32>; +template class armnn::ClPadWorkload<armnn::DataType::QuantisedAsymm8>; diff --git a/src/backends/ClWorkloads/ClPadWorkload.hpp b/src/backends/ClWorkloads/ClPadWorkload.hpp new file mode 100644 index 0000000000..0ec560d545 --- /dev/null +++ b/src/backends/ClWorkloads/ClPadWorkload.hpp @@ -0,0 +1,32 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#pragma once + +#include "backends/WorkloadData.hpp" +#include "backends/Workload.hpp" +#include <arm_compute/runtime/CL/functions/CLPadLayer.h> + +namespace armnn { + +template <armnn::DataType... dataTypes> +class ClPadWorkload : public TypedWorkload<PadQueueDescriptor, dataTypes...> +{ +public: + ClPadWorkload(const PadQueueDescriptor& descriptor, const WorkloadInfo& info); + + void Execute() const override; + +private: + mutable arm_compute::CLPadLayer m_Layer; +}; + +bool ClPadValidate(const TensorInfo& input, + const TensorInfo& output, + const PadDescriptor& descriptor, + std::string* reasonIfUnsupported); + +} //namespace armnn + diff --git a/src/backends/ClWorkloads/backend.mk b/src/backends/ClWorkloads/backend.mk index 9a20961287..9ac5004f64 100644 --- a/src/backends/ClWorkloads/backend.mk +++ b/src/backends/ClWorkloads/backend.mk @@ -33,6 +33,7 @@ BACKEND_SOURCES := \ ClMergerUint8Workload.cpp \ ClMultiplicationFloatWorkload.cpp \ ClNormalizationFloatWorkload.cpp \ + ClPadWorkload.cpp \ ClPermuteWorkload.cpp \ ClPooling2dBaseWorkload.cpp \ ClPooling2dFloatWorkload.cpp \ |