diff options
Diffstat (limited to 'src/backends/cl/workloads')
-rw-r--r-- | src/backends/cl/workloads/CMakeLists.txt | 1 | ||||
-rw-r--r-- | src/backends/cl/workloads/ClRankWorkload.hpp | 30 | ||||
-rw-r--r-- | src/backends/cl/workloads/ClWorkloadUtils.hpp | 7 | ||||
-rw-r--r-- | src/backends/cl/workloads/ClWorkloads.hpp | 1 |
4 files changed, 39 insertions, 0 deletions
diff --git a/src/backends/cl/workloads/CMakeLists.txt b/src/backends/cl/workloads/CMakeLists.txt index 6118d9bbe1..7427ea018d 100644 --- a/src/backends/cl/workloads/CMakeLists.txt +++ b/src/backends/cl/workloads/CMakeLists.txt @@ -86,6 +86,7 @@ list(APPEND armnnClBackendWorkloads_sources ClQuantizedLstmWorkload.hpp ClQuantizeWorkload.cpp ClQuantizeWorkload.hpp + ClRankWorkload.hpp ClReshapeWorkload.cpp ClReshapeWorkload.hpp ClResizeWorkload.cpp diff --git a/src/backends/cl/workloads/ClRankWorkload.hpp b/src/backends/cl/workloads/ClRankWorkload.hpp new file mode 100644 index 0000000000..0a7bccf6c6 --- /dev/null +++ b/src/backends/cl/workloads/ClRankWorkload.hpp @@ -0,0 +1,30 @@ +// +// Copyright © 2020 Arm Ltd and Contributors. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#pragma once + +#include <backendsCommon/Workload.hpp> +#include <backendsCommon/WorkloadData.hpp> + +#include "ClWorkloadUtils.hpp" + +namespace armnn +{ + +struct ClRankWorkload : public BaseWorkload<RankQueueDescriptor> +{ +public: + using BaseWorkload<RankQueueDescriptor>::BaseWorkload; + virtual void Execute() const override + { + const ClTensorHandle* clTensorHandle = PolymorphicDowncast<const ClTensorHandle*>(m_Data.m_Inputs[0]); + const int32_t rank = static_cast<int32_t>(clTensorHandle->GetShape().GetNumDimensions()); + + std::memcpy(GetOutputTensorData<void>(0, m_Data), &rank, sizeof(int32_t)); + m_Data.m_Outputs[0]->Unmap(); + } +}; + +} //namespace armnn diff --git a/src/backends/cl/workloads/ClWorkloadUtils.hpp b/src/backends/cl/workloads/ClWorkloadUtils.hpp index 89f13a577a..b0cc8b4ed5 100644 --- a/src/backends/cl/workloads/ClWorkloadUtils.hpp +++ b/src/backends/cl/workloads/ClWorkloadUtils.hpp @@ -143,4 +143,11 @@ inline void RunClFunction(arm_compute::IFunction& function, const CheckLocation& } } +template <typename DataType, typename PayloadType> +DataType* GetOutputTensorData(unsigned int idx, const PayloadType& data) +{ + ITensorHandle* tensorHandle = data.m_Outputs[idx]; + return reinterpret_cast<DataType*>(tensorHandle->Map()); +} + } //namespace armnn diff --git a/src/backends/cl/workloads/ClWorkloads.hpp b/src/backends/cl/workloads/ClWorkloads.hpp index efcccb35c3..0045e7a77f 100644 --- a/src/backends/cl/workloads/ClWorkloads.hpp +++ b/src/backends/cl/workloads/ClWorkloads.hpp @@ -43,6 +43,7 @@ #include "ClQLstmWorkload.hpp" #include "ClQuantizeWorkload.hpp" #include "ClQuantizedLstmWorkload.hpp" +#include "ClRankWorkload.hpp" #include "ClReshapeWorkload.hpp" #include "ClResizeWorkload.hpp" #include "ClRsqrtWorkload.hpp" |