diff options
Diffstat (limited to 'src/backends/cl')
-rw-r--r-- | src/backends/cl/ClWorkloadFactory.cpp | 6 | ||||
-rw-r--r-- | src/backends/cl/ClWorkloadFactory.hpp | 3 | ||||
-rw-r--r-- | src/backends/cl/test/ClLayerTests.cpp | 29 | ||||
-rw-r--r-- | src/backends/cl/workloads/CMakeLists.txt | 1 | ||||
-rw-r--r-- | src/backends/cl/workloads/ClRankWorkload.hpp | 30 | ||||
-rw-r--r-- | src/backends/cl/workloads/ClWorkloadUtils.hpp | 7 | ||||
-rw-r--r-- | src/backends/cl/workloads/ClWorkloads.hpp | 1 |
7 files changed, 77 insertions, 0 deletions
diff --git a/src/backends/cl/ClWorkloadFactory.cpp b/src/backends/cl/ClWorkloadFactory.cpp index 35186f286a..5a5cb89204 100644 --- a/src/backends/cl/ClWorkloadFactory.cpp +++ b/src/backends/cl/ClWorkloadFactory.cpp @@ -560,6 +560,12 @@ std::unique_ptr<IWorkload> ClWorkloadFactory::CreateQuantizedLstm(const Quantize return MakeWorkload<ClQuantizedLstmWorkload>(descriptor, info, m_CLCompileContext); } +std::unique_ptr<IWorkload> ClWorkloadFactory::CreateRank(const RankQueueDescriptor& descriptor, + const WorkloadInfo& info) const +{ + return std::make_unique<ClRankWorkload>(descriptor, info); +} + std::unique_ptr<IWorkload> ClWorkloadFactory::CreateReshape(const ReshapeQueueDescriptor& descriptor, const WorkloadInfo& info) const { diff --git a/src/backends/cl/ClWorkloadFactory.hpp b/src/backends/cl/ClWorkloadFactory.hpp index c8812cfe1b..66aea8498f 100644 --- a/src/backends/cl/ClWorkloadFactory.hpp +++ b/src/backends/cl/ClWorkloadFactory.hpp @@ -203,6 +203,9 @@ public: std::unique_ptr<IWorkload> CreateQuantizedLstm(const QuantizedLstmQueueDescriptor& descriptor, const WorkloadInfo& info) const override; + std::unique_ptr<IWorkload> CreateRank(const RankQueueDescriptor& descriptor, + const WorkloadInfo& info) const override; + std::unique_ptr<IWorkload> CreateReshape(const ReshapeQueueDescriptor& descriptor, const WorkloadInfo& info) const override; diff --git a/src/backends/cl/test/ClLayerTests.cpp b/src/backends/cl/test/ClLayerTests.cpp index 7d40a693a2..018a62df95 100644 --- a/src/backends/cl/test/ClLayerTests.cpp +++ b/src/backends/cl/test/ClLayerTests.cpp @@ -340,6 +340,35 @@ ARMNN_AUTO_TEST_CASE_WITH_THF(Multiplication5d, Multiplication5dTest) ARMNN_AUTO_TEST_CASE_WITH_THF(BatchNormFloat32, BatchNormFloat32Test) ARMNN_AUTO_TEST_CASE_WITH_THF(BatchNormFloat32Nhwc, BatchNormFloat32NhwcTest) +// Rank +ARMNN_AUTO_TEST_CASE_WITH_THF(RankDimSize1Float16, RankDimSize1Test<DataType::Float16>) +ARMNN_AUTO_TEST_CASE_WITH_THF(RankDimSize1Float32, RankDimSize1Test<DataType::Float32>) +ARMNN_AUTO_TEST_CASE_WITH_THF(RankDimSize1QAsymmU8, RankDimSize1Test<DataType::QAsymmU8>) +ARMNN_AUTO_TEST_CASE_WITH_THF(RankDimSize1Signed32, RankDimSize1Test<DataType::Signed32>) +ARMNN_AUTO_TEST_CASE_WITH_THF(RankDimSize1QSymmS16, RankDimSize1Test<DataType::QSymmS16>) +ARMNN_AUTO_TEST_CASE_WITH_THF(RankDimSize1QAsymmS8, RankDimSize1Test<DataType::QAsymmS8>) + +ARMNN_AUTO_TEST_CASE_WITH_THF(RankDimSize2Float16, RankDimSize2Test<DataType::Float16>) +ARMNN_AUTO_TEST_CASE_WITH_THF(RankDimSize2Float32, RankDimSize2Test<DataType::Float32>) +ARMNN_AUTO_TEST_CASE_WITH_THF(RankDimSize2QAsymmU8, RankDimSize2Test<DataType::QAsymmU8>) +ARMNN_AUTO_TEST_CASE_WITH_THF(RankDimSize2Signed32, RankDimSize2Test<DataType::Signed32>) +ARMNN_AUTO_TEST_CASE_WITH_THF(RankDimSize2QSymmS16, RankDimSize2Test<DataType::QSymmS16>) +ARMNN_AUTO_TEST_CASE_WITH_THF(RankDimSize2QAsymmS8, RankDimSize2Test<DataType::QAsymmS8>) + +ARMNN_AUTO_TEST_CASE_WITH_THF(RankDimSize3Float16, RankDimSize3Test<DataType::Float16>) +ARMNN_AUTO_TEST_CASE_WITH_THF(RankDimSize3Float32, RankDimSize3Test<DataType::Float32>) +ARMNN_AUTO_TEST_CASE_WITH_THF(RankDimSize3QAsymmU8, RankDimSize3Test<DataType::QAsymmU8>) +ARMNN_AUTO_TEST_CASE_WITH_THF(RankDimSize3Signed32, RankDimSize3Test<DataType::Signed32>) +ARMNN_AUTO_TEST_CASE_WITH_THF(RankDimSize3QSymmS16, RankDimSize3Test<DataType::QSymmS16>) +ARMNN_AUTO_TEST_CASE_WITH_THF(RankDimSize3QAsymmS8, RankDimSize3Test<DataType::QAsymmS8>) + +ARMNN_AUTO_TEST_CASE_WITH_THF(RankDimSize4Float16, RankDimSize4Test<DataType::Float16>) +ARMNN_AUTO_TEST_CASE_WITH_THF(RankDimSize4Float32, RankDimSize4Test<DataType::Float32>) +ARMNN_AUTO_TEST_CASE_WITH_THF(RankDimSize4QAsymmU8, RankDimSize4Test<DataType::QAsymmU8>) +ARMNN_AUTO_TEST_CASE_WITH_THF(RankDimSize4Signed32, RankDimSize4Test<DataType::Signed32>) +ARMNN_AUTO_TEST_CASE_WITH_THF(RankDimSize4QSymmS16, RankDimSize4Test<DataType::QSymmS16>) +ARMNN_AUTO_TEST_CASE_WITH_THF(RankDimSize4QAsymmS8, RankDimSize4Test<DataType::QAsymmS8>) + // InstanceNormalization ARMNN_AUTO_TEST_CASE_WITH_THF(InstanceNormFloat32Nchw, InstanceNormFloat32Test, DataLayout::NCHW); ARMNN_AUTO_TEST_CASE_WITH_THF(InstanceNormFloat16Nchw, InstanceNormFloat16Test, DataLayout::NCHW); diff --git a/src/backends/cl/workloads/CMakeLists.txt b/src/backends/cl/workloads/CMakeLists.txt index 6118d9bbe1..7427ea018d 100644 --- a/src/backends/cl/workloads/CMakeLists.txt +++ b/src/backends/cl/workloads/CMakeLists.txt @@ -86,6 +86,7 @@ list(APPEND armnnClBackendWorkloads_sources ClQuantizedLstmWorkload.hpp ClQuantizeWorkload.cpp ClQuantizeWorkload.hpp + ClRankWorkload.hpp ClReshapeWorkload.cpp ClReshapeWorkload.hpp ClResizeWorkload.cpp diff --git a/src/backends/cl/workloads/ClRankWorkload.hpp b/src/backends/cl/workloads/ClRankWorkload.hpp new file mode 100644 index 0000000000..0a7bccf6c6 --- /dev/null +++ b/src/backends/cl/workloads/ClRankWorkload.hpp @@ -0,0 +1,30 @@ +// +// Copyright © 2020 Arm Ltd and Contributors. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#pragma once + +#include <backendsCommon/Workload.hpp> +#include <backendsCommon/WorkloadData.hpp> + +#include "ClWorkloadUtils.hpp" + +namespace armnn +{ + +struct ClRankWorkload : public BaseWorkload<RankQueueDescriptor> +{ +public: + using BaseWorkload<RankQueueDescriptor>::BaseWorkload; + virtual void Execute() const override + { + const ClTensorHandle* clTensorHandle = PolymorphicDowncast<const ClTensorHandle*>(m_Data.m_Inputs[0]); + const int32_t rank = static_cast<int32_t>(clTensorHandle->GetShape().GetNumDimensions()); + + std::memcpy(GetOutputTensorData<void>(0, m_Data), &rank, sizeof(int32_t)); + m_Data.m_Outputs[0]->Unmap(); + } +}; + +} //namespace armnn diff --git a/src/backends/cl/workloads/ClWorkloadUtils.hpp b/src/backends/cl/workloads/ClWorkloadUtils.hpp index 89f13a577a..b0cc8b4ed5 100644 --- a/src/backends/cl/workloads/ClWorkloadUtils.hpp +++ b/src/backends/cl/workloads/ClWorkloadUtils.hpp @@ -143,4 +143,11 @@ inline void RunClFunction(arm_compute::IFunction& function, const CheckLocation& } } +template <typename DataType, typename PayloadType> +DataType* GetOutputTensorData(unsigned int idx, const PayloadType& data) +{ + ITensorHandle* tensorHandle = data.m_Outputs[idx]; + return reinterpret_cast<DataType*>(tensorHandle->Map()); +} + } //namespace armnn diff --git a/src/backends/cl/workloads/ClWorkloads.hpp b/src/backends/cl/workloads/ClWorkloads.hpp index efcccb35c3..0045e7a77f 100644 --- a/src/backends/cl/workloads/ClWorkloads.hpp +++ b/src/backends/cl/workloads/ClWorkloads.hpp @@ -43,6 +43,7 @@ #include "ClQLstmWorkload.hpp" #include "ClQuantizeWorkload.hpp" #include "ClQuantizedLstmWorkload.hpp" +#include "ClRankWorkload.hpp" #include "ClReshapeWorkload.hpp" #include "ClResizeWorkload.hpp" #include "ClRsqrtWorkload.hpp" |