aboutsummaryrefslogtreecommitdiff
path: root/arm_compute/core/utils
diff options
context:
space:
mode:
Diffstat (limited to 'arm_compute/core/utils')
-rw-r--r--arm_compute/core/utils/misc/ShapeCalculator.h53
1 files changed, 23 insertions, 30 deletions
diff --git a/arm_compute/core/utils/misc/ShapeCalculator.h b/arm_compute/core/utils/misc/ShapeCalculator.h
index 9e7c981814..94bd3aca03 100644
--- a/arm_compute/core/utils/misc/ShapeCalculator.h
+++ b/arm_compute/core/utils/misc/ShapeCalculator.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2022 Arm Limited.
+ * Copyright (c) 2017-2023 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -1537,39 +1537,32 @@ inline TensorShape compute_pool3d_shape(const TensorShape &src, Pooling3dLayerIn
*/
inline TensorShape compute_gather_shape(const TensorShape &input_shape, const TensorShape &indices_shape, uint32_t actual_axis)
{
- ARM_COMPUTE_ERROR_ON(input_shape.num_dimensions() > 4);
- ARM_COMPUTE_ERROR_ON(actual_axis >= input_shape.num_dimensions());
- ARM_COMPUTE_ERROR_ON(indices_shape.num_dimensions() > 3);
- TensorShape output_shape = input_shape;
- if(indices_shape.num_dimensions() == 1u)
+ const auto input_num_dims = input_shape.num_dimensions();
+ const auto indices_num_dims = indices_shape.num_dimensions();
+
+ ARM_COMPUTE_ERROR_ON(actual_axis >= input_num_dims);
+ ARM_COMPUTE_ERROR_ON(input_num_dims + indices_num_dims - 1 > Coordinates::num_max_dimensions);
+
+ TensorShape output_shape;
+ size_t dim_no = 0;
+
+ for(; dim_no < actual_axis; ++dim_no)
{
- output_shape[actual_axis] = indices_shape[0];
+ output_shape.set(dim_no, input_shape[dim_no]);
}
- else
+
+ for(; dim_no < actual_axis + indices_num_dims; ++dim_no)
{
- const auto ind_num_dims
- {
- indices_shape.num_dimensions()
- };
- output_shape.shift_right(ind_num_dims - 1);
- switch(actual_axis)
- {
- case 1:
- {
- output_shape[0] = input_shape[0];
- for(size_t idx = 0; idx < ind_num_dims; ++idx)
- {
- output_shape.set(actual_axis + idx, indices_shape[idx], false);
- }
- break;
- }
- default:
- {
- // 2d and 3d indices are only supported for axis == 1
- ARM_COMPUTE_ERROR_ON(actual_axis != 1 && indices_shape.num_dimensions() > 1);
- }
- }
+ output_shape.set(dim_no, indices_shape[dim_no - actual_axis]);
+ }
+
+ for(; dim_no < input_num_dims + indices_num_dims - 1; ++dim_no)
+ {
+ output_shape.set(dim_no, input_shape[dim_no + 1 - indices_num_dims]);
}
+
+ ARM_COMPUTE_ERROR_ON(input_shape.total_size() * indices_shape.total_size() != output_shape.total_size() * input_shape[actual_axis]);
+
return output_shape;
}
} // namespace shape_calculator