aboutsummaryrefslogtreecommitdiff
path: root/src/backends/cl/workloads/ClGatherWorkload.cpp
diff options
context:
space:
mode:
authorTeresa Charlin <teresa.charlinreyes@arm.com>2020-04-10 22:34:48 +0100
committerTeresaARM <teresa.charlinreyes@arm.com>2020-06-02 10:28:52 +0000
commit9ad2e5baaf8819c40416ab1e91e14f8b804ba723 (patch)
tree4f4b347f5e02e0261ee23d2e60634b468fa132fb /src/backends/cl/workloads/ClGatherWorkload.cpp
parent4a0c9b99deb88a0ec5de7997f09062686915c6cc (diff)
downloadarmnn-9ad2e5baaf8819c40416ab1e91e14f8b804ba723.tar.gz
IVGCVSW-3844 Add CL GATHER Workload
Signed-off-by: Teresa Charlin <teresa.charlinreyes@arm.com> Change-Id: I3d44c05acb40566cd4149417fca5bfd260f301e1
Diffstat (limited to 'src/backends/cl/workloads/ClGatherWorkload.cpp')
-rw-r--r--src/backends/cl/workloads/ClGatherWorkload.cpp48
1 files changed, 48 insertions, 0 deletions
diff --git a/src/backends/cl/workloads/ClGatherWorkload.cpp b/src/backends/cl/workloads/ClGatherWorkload.cpp
new file mode 100644
index 0000000000..068487039b
--- /dev/null
+++ b/src/backends/cl/workloads/ClGatherWorkload.cpp
@@ -0,0 +1,48 @@
+//
+// Copyright © 2020 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include "ClGatherWorkload.hpp"
+#include "ClWorkloadUtils.hpp"
+#include <aclCommon/ArmComputeUtils.hpp>
+#include <cl/ClTensorHandle.hpp>
+
+using namespace armnn::armcomputetensorutils;
+
+namespace armnn
+{
+arm_compute::Status ClGatherWorkloadValidate(const TensorInfo& input,
+ const TensorInfo& indices,
+ const TensorInfo& output)
+{
+ const arm_compute::TensorInfo aclInput = BuildArmComputeTensorInfo(input);
+ const arm_compute::TensorInfo aclIndices = BuildArmComputeTensorInfo(indices);
+ const arm_compute::TensorInfo aclOutput = BuildArmComputeTensorInfo(output);
+
+ int aclAxis = ComputeAclAxis(0, input);
+
+ return arm_compute::CLGather::validate(&aclInput, &aclIndices, &aclOutput, aclAxis);
+}
+
+ClGatherWorkload::ClGatherWorkload(const GatherQueueDescriptor& descriptor,
+ const WorkloadInfo& info)
+ : BaseWorkload<GatherQueueDescriptor>(descriptor, info)
+{
+ m_Data.ValidateInputsOutputs("ClGatherWorkload", 1, 1);
+
+ arm_compute::ICLTensor& input = static_cast<IClTensorHandle*>(m_Data.m_Inputs[0])->GetTensor();
+ arm_compute::ICLTensor& indices = static_cast<IClTensorHandle*>(m_Data.m_Inputs[1])->GetTensor();
+ arm_compute::ICLTensor& output = static_cast<IClTensorHandle*>(m_Data.m_Outputs[0])->GetTensor();
+
+ int aclAxis = ComputeAclAxis(0, info.m_InputTensorInfos[0]);
+
+ m_Layer.configure(&input, &indices, &output, aclAxis);
+};
+
+void ClGatherWorkload::Execute() const
+{
+ ARMNN_SCOPED_PROFILING_EVENT_CL("ClGatherWorkload_Execute");
+ RunClFunction(m_Layer, CHECK_LOCATION());
+}
+} // namespace armnn