aboutsummaryrefslogtreecommitdiff
path: root/arm_compute/core/utils
diff options
context:
space:
mode:
authorViet-Hoa Do <viet-hoa.do@arm.com>2023-02-24 15:52:21 +0000
committerViet-Hoa Do <viet-hoa.do@arm.com>2023-03-08 15:09:25 +0000
commit37c989a58a04985dfdc21089c7dacc7e1925a4d0 (patch)
tree6e60ada38ceaf2b651cc44a481004abbb89ceae4 /arm_compute/core/utils
parent98aca0fda7f7c7c16bd2d1cf5386246ad796d9de (diff)
downloadComputeLibrary-37c989a58a04985dfdc21089c7dacc7e1925a4d0.tar.gz
Add support for arbitrary parameters for CPU Gather
* The shape of input and indices tensors, and the gather axis can be any number, as long as these are valid and the output tensor doesn't have more dimensions than the library supports. * Update the reference code to be more generic and straightforward. * Add necessary test cases. Signed-off-by: Viet-Hoa Do <viet-hoa.do@arm.com> Resolves: COMPMID-5919 Change-Id: Ic7e2032777aa97ecc147f61d5388528697508ab1 Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/9199 Tested-by: Arm Jenkins <bsgcomp@arm.com> Reviewed-by: Gunes Bayir <gunes.bayir@arm.com> Comments-Addressed: Arm Jenkins <bsgcomp@arm.com> Benchmark: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'arm_compute/core/utils')
-rw-r--r--arm_compute/core/utils/misc/ShapeCalculator.h53
1 files changed, 23 insertions, 30 deletions
diff --git a/arm_compute/core/utils/misc/ShapeCalculator.h b/arm_compute/core/utils/misc/ShapeCalculator.h
index 9e7c981814..94bd3aca03 100644
--- a/arm_compute/core/utils/misc/ShapeCalculator.h
+++ b/arm_compute/core/utils/misc/ShapeCalculator.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2022 Arm Limited.
+ * Copyright (c) 2017-2023 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -1537,39 +1537,32 @@ inline TensorShape compute_pool3d_shape(const TensorShape &src, Pooling3dLayerIn
*/
inline TensorShape compute_gather_shape(const TensorShape &input_shape, const TensorShape &indices_shape, uint32_t actual_axis)
{
- ARM_COMPUTE_ERROR_ON(input_shape.num_dimensions() > 4);
- ARM_COMPUTE_ERROR_ON(actual_axis >= input_shape.num_dimensions());
- ARM_COMPUTE_ERROR_ON(indices_shape.num_dimensions() > 3);
- TensorShape output_shape = input_shape;
- if(indices_shape.num_dimensions() == 1u)
+ const auto input_num_dims = input_shape.num_dimensions();
+ const auto indices_num_dims = indices_shape.num_dimensions();
+
+ ARM_COMPUTE_ERROR_ON(actual_axis >= input_num_dims);
+ ARM_COMPUTE_ERROR_ON(input_num_dims + indices_num_dims - 1 > Coordinates::num_max_dimensions);
+
+ TensorShape output_shape;
+ size_t dim_no = 0;
+
+ for(; dim_no < actual_axis; ++dim_no)
{
- output_shape[actual_axis] = indices_shape[0];
+ output_shape.set(dim_no, input_shape[dim_no]);
}
- else
+
+ for(; dim_no < actual_axis + indices_num_dims; ++dim_no)
{
- const auto ind_num_dims
- {
- indices_shape.num_dimensions()
- };
- output_shape.shift_right(ind_num_dims - 1);
- switch(actual_axis)
- {
- case 1:
- {
- output_shape[0] = input_shape[0];
- for(size_t idx = 0; idx < ind_num_dims; ++idx)
- {
- output_shape.set(actual_axis + idx, indices_shape[idx], false);
- }
- break;
- }
- default:
- {
- // 2d and 3d indices are only supported for axis == 1
- ARM_COMPUTE_ERROR_ON(actual_axis != 1 && indices_shape.num_dimensions() > 1);
- }
- }
+ output_shape.set(dim_no, indices_shape[dim_no - actual_axis]);
+ }
+
+ for(; dim_no < input_num_dims + indices_num_dims - 1; ++dim_no)
+ {
+ output_shape.set(dim_no, input_shape[dim_no + 1 - indices_num_dims]);
}
+
+ ARM_COMPUTE_ERROR_ON(input_shape.total_size() * indices_shape.total_size() != output_shape.total_size() * input_shape[actual_axis]);
+
return output_shape;
}
} // namespace shape_calculator