diff options
author | Teresa Charlin <teresa.charlinreyes@arm.com> | 2020-04-10 22:34:48 +0100 |
---|---|---|
committer | TeresaARM <teresa.charlinreyes@arm.com> | 2020-06-02 10:28:52 +0000 |
commit | 9ad2e5baaf8819c40416ab1e91e14f8b804ba723 (patch) | |
tree | 4f4b347f5e02e0261ee23d2e60634b468fa132fb | |
parent | 4a0c9b99deb88a0ec5de7997f09062686915c6cc (diff) | |
download | armnn-9ad2e5baaf8819c40416ab1e91e14f8b804ba723.tar.gz |
IVGCVSW-3844 Add CL GATHER Workload
Signed-off-by: Teresa Charlin <teresa.charlinreyes@arm.com>
Change-Id: I3d44c05acb40566cd4149417fca5bfd260f301e1
-rw-r--r-- | src/backends/cl/ClLayerSupport.cpp | 13 | ||||
-rw-r--r-- | src/backends/cl/ClLayerSupport.hpp | 5 | ||||
-rw-r--r-- | src/backends/cl/ClWorkloadFactory.cpp | 2 | ||||
-rw-r--r-- | src/backends/cl/backend.mk | 1 | ||||
-rw-r--r-- | src/backends/cl/test/ClLayerTests.cpp | 6 | ||||
-rw-r--r-- | src/backends/cl/workloads/CMakeLists.txt | 2 | ||||
-rw-r--r-- | src/backends/cl/workloads/ClConstantWorkload.cpp | 10 | ||||
-rw-r--r-- | src/backends/cl/workloads/ClGatherWorkload.cpp | 48 | ||||
-rw-r--r-- | src/backends/cl/workloads/ClGatherWorkload.hpp | 28 | ||||
-rw-r--r-- | src/backends/cl/workloads/ClWorkloads.hpp | 1 |
10 files changed, 113 insertions, 3 deletions
diff --git a/src/backends/cl/ClLayerSupport.cpp b/src/backends/cl/ClLayerSupport.cpp
index c8d3816a4c..9c0cf61f9c 100644
--- a/src/backends/cl/ClLayerSupport.cpp
+++ b/src/backends/cl/ClLayerSupport.cpp
@@ -34,6 +34,7 @@
 #include "workloads/ClExpWorkload.hpp"
 #include "workloads/ClFloorFloatWorkload.hpp"
 #include "workloads/ClFullyConnectedWorkload.hpp"
+#include "workloads/ClGatherWorkload.hpp"
 #include "workloads/ClInstanceNormalizationWorkload.hpp"
 #include "workloads/ClL2NormalizationFloatWorkload.hpp"
 #include "workloads/ClLstmFloatWorkload.hpp"
@@ -453,6 +454,18 @@ bool ClLayerSupport::IsFullyConnectedSupported(const TensorInfo& input,
                                    descriptor);
 }
 
+bool ClLayerSupport::IsGatherSupported(const TensorInfo& input0, // lines up with ClGatherWorkloadValidate's `input`
+                                       const TensorInfo& input1, // lines up with ClGatherWorkloadValidate's `indices`
+                                       const TensorInfo& output,
+                                       Optional<std::string&> reasonIfUnsupported) const
+{
+    // NOTE(review): the macro presumably calls ClGatherWorkloadValidate(input0, input1, output) and
+    // reports a failure through reasonIfUnsupported - confirm against its definition in this file.
+    FORWARD_WORKLOAD_VALIDATE_FUNC(ClGatherWorkloadValidate, reasonIfUnsupported, input0,
+                                   input1,
+                                   output);
+}
+
 bool ClLayerSupport::IsGreaterSupported(const TensorInfo& input0,
                                         const TensorInfo& input1,
                                         const TensorInfo& output,
diff --git a/src/backends/cl/ClLayerSupport.hpp b/src/backends/cl/ClLayerSupport.hpp
index d785f54387..67fd230ea2 100644
--- a/src/backends/cl/ClLayerSupport.hpp
+++ b/src/backends/cl/ClLayerSupport.hpp
@@ -119,6 +119,11 @@ public:
                                    const FullyConnectedDescriptor& descriptor,
                                    Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
 
+    bool IsGatherSupported(const TensorInfo& input0, // Gather support query for the CL backend
+                           const TensorInfo& input1,
+                           const TensorInfo& output, // default below matches the sibling Is*Supported declarations
+                           Optional<std::string&> reasonIfUnsupported = EmptyOptional()) const override;
+
     ARMNN_DEPRECATED_MSG("Use IsComparisonSupported instead")
     bool IsGreaterSupported(const TensorInfo& input0,
                             const TensorInfo& input1,
diff --git a/src/backends/cl/ClWorkloadFactory.cpp b/src/backends/cl/ClWorkloadFactory.cpp
index cabc3466aa..ff66c6bada 100644
--- a/src/backends/cl/ClWorkloadFactory.cpp
+++ b/src/backends/cl/ClWorkloadFactory.cpp
@@ -302,7 +302,7 @@ std::unique_ptr<IWorkload>
ClWorkloadFactory::CreateFullyConnected(const FullyCo
 std::unique_ptr<IWorkload> ClWorkloadFactory::CreateGather(const GatherQueueDescriptor& descriptor,
                                                            const WorkloadInfo& info) const
 {
-    return MakeWorkload<NullWorkload, NullWorkload>(descriptor, info);
+    return MakeWorkload<ClGatherWorkload>(descriptor, info); // was MakeWorkload<NullWorkload, NullWorkload> before this change
 }
 
 std::unique_ptr<IWorkload> ClWorkloadFactory::CreateGreater(const GreaterQueueDescriptor& descriptor,
diff --git a/src/backends/cl/backend.mk b/src/backends/cl/backend.mk
index 740aee6d67..721a9ec26e 100644
--- a/src/backends/cl/backend.mk
+++ b/src/backends/cl/backend.mk
@@ -41,6 +41,7 @@ BACKEND_SOURCES := \
         workloads/ClExpWorkload.cpp \
         workloads/ClFloorFloatWorkload.cpp \
         workloads/ClFullyConnectedWorkload.cpp \
+        workloads/ClGatherWorkload.cpp \
         workloads/ClInstanceNormalizationWorkload.cpp \
         workloads/ClL2NormalizationFloatWorkload.cpp \
         workloads/ClLstmFloatWorkload.cpp \
diff --git a/src/backends/cl/test/ClLayerTests.cpp b/src/backends/cl/test/ClLayerTests.cpp
index e6492f612a..eb7dce6930 100644
--- a/src/backends/cl/test/ClLayerTests.cpp
+++ b/src/backends/cl/test/ClLayerTests.cpp
@@ -465,6 +465,12 @@ ARMNN_AUTO_TEST_CASE(DepthToSpaceNhwcInt16_4, DepthToSpaceTest4<DataType::QSymmS
 // Floor
 ARMNN_AUTO_TEST_CASE(SimpleFloor, SimpleFloorTest<DataType::Float32>)
 
+// Gather
+ARMNN_AUTO_TEST_CASE(Gather1dParamsFloat32, Gather1dParamsFloat32Test)
+ARMNN_AUTO_TEST_CASE(Gather1dParamsUint8, Gather1dParamsUint8Test)
+ARMNN_AUTO_TEST_CASE(GatherMultiDimParamsFloat32, GatherMultiDimParamsFloat32Test)
+ARMNN_AUTO_TEST_CASE(GatherMultiDimParamsUint8, GatherMultiDimParamsUint8Test)
+
 // Reshape
 ARMNN_AUTO_TEST_CASE(SimpleReshapeFloat32, SimpleReshapeTest<DataType::Float32>)
 ARMNN_AUTO_TEST_CASE(SimpleReshapeInt8, SimpleReshapeTest<DataType::QAsymmS8>)
diff --git a/src/backends/cl/workloads/CMakeLists.txt b/src/backends/cl/workloads/CMakeLists.txt
index 7d9df07ed9..6baeae0a80 100644
--- a/src/backends/cl/workloads/CMakeLists.txt
+++ 
b/src/backends/cl/workloads/CMakeLists.txt @@ -42,6 +42,8 @@ list(APPEND armnnClBackendWorkloads_sources ClFloorFloatWorkload.hpp ClFullyConnectedWorkload.cpp ClFullyConnectedWorkload.hpp + ClGatherWorkload.cpp + ClGatherWorkload.hpp ClInstanceNormalizationWorkload.cpp ClInstanceNormalizationWorkload.hpp ClL2NormalizationFloatWorkload.cpp diff --git a/src/backends/cl/workloads/ClConstantWorkload.cpp b/src/backends/cl/workloads/ClConstantWorkload.cpp index bae7446753..d6b5c57a7e 100644 --- a/src/backends/cl/workloads/ClConstantWorkload.cpp +++ b/src/backends/cl/workloads/ClConstantWorkload.cpp @@ -19,14 +19,15 @@ arm_compute::Status ClConstantWorkloadValidate(const TensorInfo& output) { const arm_compute::TensorInfo neonOutputInfo = armcomputetensorutils::BuildArmComputeTensorInfo(output); - std::array<arm_compute::DataType,7> supportedTypes = { + std::array<arm_compute::DataType,8> supportedTypes = { arm_compute::DataType::F16, arm_compute::DataType::F32, arm_compute::DataType::QASYMM8, arm_compute::DataType::QASYMM8_SIGNED, arm_compute::DataType::QSYMM16, arm_compute::DataType::QSYMM8, - arm_compute::DataType::QSYMM8_PER_CHANNEL + arm_compute::DataType::QSYMM8_PER_CHANNEL, + arm_compute::DataType::S32 }; auto it = std::find(begin(supportedTypes), end(supportedTypes), neonOutputInfo.data_type()); @@ -95,6 +96,11 @@ void ClConstantWorkload::Execute() const CopyArmComputeClTensorData(output, data.m_LayerOutput->GetConstTensor<int8_t>()); break; } + case arm_compute::DataType::S32: + { + CopyArmComputeClTensorData(output, data.m_LayerOutput->GetConstTensor<int32_t>()); + break; + } default: { ARMNN_ASSERT_MSG(false, "Unknown data type"); diff --git a/src/backends/cl/workloads/ClGatherWorkload.cpp b/src/backends/cl/workloads/ClGatherWorkload.cpp new file mode 100644 index 0000000000..068487039b --- /dev/null +++ b/src/backends/cl/workloads/ClGatherWorkload.cpp @@ -0,0 +1,48 @@ +// +// Copyright © 2020 Arm Ltd. All rights reserved. 
+// SPDX-License-Identifier: MIT
+//
+
+#include "ClGatherWorkload.hpp"
+#include "ClWorkloadUtils.hpp"
+#include <aclCommon/ArmComputeUtils.hpp>
+#include <cl/ClTensorHandle.hpp>
+
+using namespace armnn::armcomputetensorutils;
+
+namespace armnn
+{
+arm_compute::Status ClGatherWorkloadValidate(const TensorInfo& input,   // data tensor info
+                                             const TensorInfo& indices, // index tensor info
+                                             const TensorInfo& output)  // result tensor info
+{
+    const arm_compute::TensorInfo aclInput   = BuildArmComputeTensorInfo(input);
+    const arm_compute::TensorInfo aclIndices = BuildArmComputeTensorInfo(indices);
+    const arm_compute::TensorInfo aclOutput  = BuildArmComputeTensorInfo(output);
+    // Axis is fixed at 0. NOTE(review): ComputeAclAxis presumably converts it to ACL's dimension order - confirm.
+    int aclAxis = ComputeAclAxis(0, input);
+    // Validation only: hand the translated tensor infos to CLGather::validate and return its status unchanged.
+    return arm_compute::CLGather::validate(&aclInput, &aclIndices, &aclOutput, aclAxis);
+}
+// Configures m_Layer from the descriptor's two inputs (data at index 0, indices at index 1) and its one output.
+ClGatherWorkload::ClGatherWorkload(const GatherQueueDescriptor& descriptor,
+                                   const WorkloadInfo& info)
+    : BaseWorkload<GatherQueueDescriptor>(descriptor, info)
+{
+    m_Data.ValidateInputsOutputs("ClGatherWorkload", 2, 1); // two inputs: m_Inputs[0] and m_Inputs[1] are both read below
+
+    arm_compute::ICLTensor& input   = static_cast<IClTensorHandle*>(m_Data.m_Inputs[0])->GetTensor();
+    arm_compute::ICLTensor& indices = static_cast<IClTensorHandle*>(m_Data.m_Inputs[1])->GetTensor();
+    arm_compute::ICLTensor& output  = static_cast<IClTensorHandle*>(m_Data.m_Outputs[0])->GetTensor();
+
+    int aclAxis = ComputeAclAxis(0, info.m_InputTensorInfos[0]); // same fixed axis as ClGatherWorkloadValidate
+
+    m_Layer.configure(&input, &indices, &output, aclAxis);
+}
+// Runs the configured CLGather through RunClFunction, wrapped in a CL profiling event.
+void ClGatherWorkload::Execute() const
+{
+    ARMNN_SCOPED_PROFILING_EVENT_CL("ClGatherWorkload_Execute");
+    RunClFunction(m_Layer, CHECK_LOCATION());
+}
+} // namespace armnn
diff --git a/src/backends/cl/workloads/ClGatherWorkload.hpp b/src/backends/cl/workloads/ClGatherWorkload.hpp
new file mode 100644
index 0000000000..5dbeaade59
--- /dev/null
+++ b/src/backends/cl/workloads/ClGatherWorkload.hpp
@@ -0,0 +1,28 @@
+//
+// Copyright © 2020 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+#include <backendsCommon/Workload.hpp>
+
+#include <arm_compute/runtime/CL/functions/CLGather.h>
+
+namespace armnn
+{
+arm_compute::Status ClGatherWorkloadValidate(const TensorInfo& input,   // defined in ClGatherWorkload.cpp
+                                             const TensorInfo& indices,
+                                             const TensorInfo& output);
+// Gather workload for the CL backend; owns the arm_compute::CLGather function it configures and runs.
+class ClGatherWorkload : public BaseWorkload<GatherQueueDescriptor>
+{
+public:
+    ClGatherWorkload(const GatherQueueDescriptor& descriptor, const WorkloadInfo& info);
+    void Execute() const override;
+
+private:
+    mutable arm_compute::CLGather m_Layer; // mutable: Execute() is const yet passes m_Layer to RunClFunction
+};
+
+} // namespace armnn
diff --git a/src/backends/cl/workloads/ClWorkloads.hpp b/src/backends/cl/workloads/ClWorkloads.hpp
index 1ae9a91b88..5c81079ad7 100644
--- a/src/backends/cl/workloads/ClWorkloads.hpp
+++ b/src/backends/cl/workloads/ClWorkloads.hpp
@@ -20,6 +20,7 @@
 #include "ClExpWorkload.hpp"
 #include "ClFloorFloatWorkload.hpp"
 #include "ClFullyConnectedWorkload.hpp"
+#include "ClGatherWorkload.hpp"
 #include "ClInstanceNormalizationWorkload.hpp"
 #include "ClL2NormalizationFloatWorkload.hpp"
 #include "ClLstmFloatWorkload.hpp" |