aboutsummaryrefslogtreecommitdiff
path: root/src/core/NEON/kernels/NEGatherKernel.h
diff options
context:
space:
mode:
authorViet-Hoa Do <viet-hoa.do@arm.com>2023-02-24 15:52:21 +0000
committerViet-Hoa Do <viet-hoa.do@arm.com>2023-03-08 15:09:25 +0000
commit37c989a58a04985dfdc21089c7dacc7e1925a4d0 (patch)
tree6e60ada38ceaf2b651cc44a481004abbb89ceae4 /src/core/NEON/kernels/NEGatherKernel.h
parent98aca0fda7f7c7c16bd2d1cf5386246ad796d9de (diff)
downloadComputeLibrary-37c989a58a04985dfdc21089c7dacc7e1925a4d0.tar.gz
Add support for arbitrary parameters for CPU Gather
* The shape of input and indices tensors, and the gather axis can be any number, as long as these are valid and the output tensor doesn't have more dimensions than the library supports. * Update the reference code to be more generic and straightforward. * Add necessary test cases. Signed-off-by: Viet-Hoa Do <viet-hoa.do@arm.com> Resolves: COMPMID-5919 Change-Id: Ic7e2032777aa97ecc147f61d5388528697508ab1 Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/9199 Tested-by: Arm Jenkins <bsgcomp@arm.com> Reviewed-by: Gunes Bayir <gunes.bayir@arm.com> Comments-Addressed: Arm Jenkins <bsgcomp@arm.com> Benchmark: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'src/core/NEON/kernels/NEGatherKernel.h')
-rw-r--r--src/core/NEON/kernels/NEGatherKernel.h28
1 files changed, 6 insertions, 22 deletions
diff --git a/src/core/NEON/kernels/NEGatherKernel.h b/src/core/NEON/kernels/NEGatherKernel.h
index 3dc0cad7be..ce69daeda7 100644
--- a/src/core/NEON/kernels/NEGatherKernel.h
+++ b/src/core/NEON/kernels/NEGatherKernel.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2019-2022 Arm Limited.
+ * Copyright (c) 2019-2023 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -81,27 +81,8 @@ public:
void run(const Window &window, const ThreadInfo &info) override;
private:
- /** Implementation of the gather operation for 0 axis.
- *
- * For gather on the 0 axis an element by element copy is performed.
- *
- * @param[in] window Region on which to run the kernel. (Must be a region of the window returned by window())
- * @param[in] info Info about running thread and CPU.
- */
- template <typename U>
- void gather_0_axis(const Window &window, const ThreadInfo &info);
-
- template <typename U>
- void gather_multiindices_1_axis(const Window &window, const ThreadInfo &info);
- /** Implementation of the gather operation.
- *
- * For 1<=axis a row-wise copy is taking place.
- *
- * @param[in] window Region on which to run the kernel. (Must be a region of the window returned by window())
- * @param[in] info Info about running thread and CPU.
- */
- template <typename U>
- void gather_n_axis(const Window &window, const ThreadInfo &info);
+ template <typename TIndex>
+ void gather_common(const Window &window, const ThreadInfo &info);
using kernel_ptr = void (NEGatherKernel::*)(const Window &window, const ThreadInfo &info);
@@ -110,6 +91,9 @@ private:
int _axis;
ITensor *_output;
kernel_ptr _func;
+
+ Strides _src_it_strides;
+ Strides _idx_it_strides;
};
} // namespace arm_compute
#endif /* ARM_COMPUTE_NEGATHERKERNEL_H */