diff options
Diffstat (limited to 'src/backends/cl/workloads')
-rw-r--r-- | src/backends/cl/workloads/CMakeLists.txt | 2 | ||||
-rw-r--r-- | src/backends/cl/workloads/ClConstantWorkload.cpp | 10 | ||||
-rw-r--r-- | src/backends/cl/workloads/ClGatherWorkload.cpp | 48 | ||||
-rw-r--r-- | src/backends/cl/workloads/ClGatherWorkload.hpp | 28 | ||||
-rw-r--r-- | src/backends/cl/workloads/ClWorkloads.hpp | 1 |
5 files changed, 87 insertions, 2 deletions
diff --git a/src/backends/cl/workloads/CMakeLists.txt b/src/backends/cl/workloads/CMakeLists.txt index 7d9df07ed9..6baeae0a80 100644 --- a/src/backends/cl/workloads/CMakeLists.txt +++ b/src/backends/cl/workloads/CMakeLists.txt @@ -42,6 +42,8 @@ list(APPEND armnnClBackendWorkloads_sources ClFloorFloatWorkload.hpp ClFullyConnectedWorkload.cpp ClFullyConnectedWorkload.hpp + ClGatherWorkload.cpp + ClGatherWorkload.hpp ClInstanceNormalizationWorkload.cpp ClInstanceNormalizationWorkload.hpp ClL2NormalizationFloatWorkload.cpp diff --git a/src/backends/cl/workloads/ClConstantWorkload.cpp b/src/backends/cl/workloads/ClConstantWorkload.cpp index bae7446753..d6b5c57a7e 100644 --- a/src/backends/cl/workloads/ClConstantWorkload.cpp +++ b/src/backends/cl/workloads/ClConstantWorkload.cpp @@ -19,14 +19,15 @@ arm_compute::Status ClConstantWorkloadValidate(const TensorInfo& output) { const arm_compute::TensorInfo neonOutputInfo = armcomputetensorutils::BuildArmComputeTensorInfo(output); - std::array<arm_compute::DataType,7> supportedTypes = { + std::array<arm_compute::DataType,8> supportedTypes = { arm_compute::DataType::F16, arm_compute::DataType::F32, arm_compute::DataType::QASYMM8, arm_compute::DataType::QASYMM8_SIGNED, arm_compute::DataType::QSYMM16, arm_compute::DataType::QSYMM8, - arm_compute::DataType::QSYMM8_PER_CHANNEL + arm_compute::DataType::QSYMM8_PER_CHANNEL, + arm_compute::DataType::S32 }; auto it = std::find(begin(supportedTypes), end(supportedTypes), neonOutputInfo.data_type()); @@ -95,6 +96,11 @@ void ClConstantWorkload::Execute() const CopyArmComputeClTensorData(output, data.m_LayerOutput->GetConstTensor<int8_t>()); break; } + case arm_compute::DataType::S32: + { + CopyArmComputeClTensorData(output, data.m_LayerOutput->GetConstTensor<int32_t>()); + break; + } default: { ARMNN_ASSERT_MSG(false, "Unknown data type"); diff --git a/src/backends/cl/workloads/ClGatherWorkload.cpp b/src/backends/cl/workloads/ClGatherWorkload.cpp new file mode 100644 index 0000000000..068487039b --- /dev/null +++ b/src/backends/cl/workloads/ClGatherWorkload.cpp @@ -0,0 +1,48 @@ +// +// Copyright © 2020 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#include "ClGatherWorkload.hpp" +#include "ClWorkloadUtils.hpp" +#include <aclCommon/ArmComputeUtils.hpp> +#include <cl/ClTensorHandle.hpp> + +using namespace armnn::armcomputetensorutils; + +namespace armnn +{ +arm_compute::Status ClGatherWorkloadValidate(const TensorInfo& input, + const TensorInfo& indices, + const TensorInfo& output) +{ + const arm_compute::TensorInfo aclInput = BuildArmComputeTensorInfo(input); + const arm_compute::TensorInfo aclIndices = BuildArmComputeTensorInfo(indices); + const arm_compute::TensorInfo aclOutput = BuildArmComputeTensorInfo(output); + + int aclAxis = ComputeAclAxis(0, input); + + return arm_compute::CLGather::validate(&aclInput, &aclIndices, &aclOutput, aclAxis); +} + +ClGatherWorkload::ClGatherWorkload(const GatherQueueDescriptor& descriptor, + const WorkloadInfo& info) + : BaseWorkload<GatherQueueDescriptor>(descriptor, info) +{ + m_Data.ValidateInputsOutputs("ClGatherWorkload", 1, 1); + + arm_compute::ICLTensor& input = static_cast<IClTensorHandle*>(m_Data.m_Inputs[0])->GetTensor(); + arm_compute::ICLTensor& indices = static_cast<IClTensorHandle*>(m_Data.m_Inputs[1])->GetTensor(); + arm_compute::ICLTensor& output = static_cast<IClTensorHandle*>(m_Data.m_Outputs[0])->GetTensor(); + + int aclAxis = ComputeAclAxis(0, info.m_InputTensorInfos[0]); + + m_Layer.configure(&input, &indices, &output, aclAxis); +}; + +void ClGatherWorkload::Execute() const +{ + ARMNN_SCOPED_PROFILING_EVENT_CL("ClGatherWorkload_Execute"); + RunClFunction(m_Layer, CHECK_LOCATION()); +} +} // namespace armnn diff --git a/src/backends/cl/workloads/ClGatherWorkload.hpp b/src/backends/cl/workloads/ClGatherWorkload.hpp new file mode 100644 index 0000000000..5dbeaade59 --- /dev/null +++ b/src/backends/cl/workloads/ClGatherWorkload.hpp @@ -0,0 +1,28 @@ +// +// Copyright © 2020 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#pragma once + +#include <backendsCommon/Workload.hpp> + +#include <arm_compute/runtime/CL/functions/CLGather.h> + +namespace armnn +{ +arm_compute::Status ClGatherWorkloadValidate(const TensorInfo& input, + const TensorInfo& indices, + const TensorInfo& output); + +class ClGatherWorkload : public BaseWorkload<GatherQueueDescriptor> +{ +public: + ClGatherWorkload(const GatherQueueDescriptor& descriptor, const WorkloadInfo& info); + void Execute() const override; + +private: + mutable arm_compute::CLGather m_Layer; +}; + +} // namespace armnn diff --git a/src/backends/cl/workloads/ClWorkloads.hpp b/src/backends/cl/workloads/ClWorkloads.hpp index 1ae9a91b88..5c81079ad7 100644 --- a/src/backends/cl/workloads/ClWorkloads.hpp +++ b/src/backends/cl/workloads/ClWorkloads.hpp @@ -20,6 +20,7 @@ #include "ClExpWorkload.hpp" #include "ClFloorFloatWorkload.hpp" #include "ClFullyConnectedWorkload.hpp" +#include "ClGatherWorkload.hpp" #include "ClInstanceNormalizationWorkload.hpp" #include "ClL2NormalizationFloatWorkload.hpp" #include "ClLstmFloatWorkload.hpp" |