diff options
Diffstat (limited to 'arm_compute/core/utils')
-rw-r--r-- | arm_compute/core/utils/misc/ShapeCalculator.h | 53 |
1 files changed, 23 insertions, 30 deletions
diff --git a/arm_compute/core/utils/misc/ShapeCalculator.h b/arm_compute/core/utils/misc/ShapeCalculator.h index 9e7c981814..94bd3aca03 100644 --- a/arm_compute/core/utils/misc/ShapeCalculator.h +++ b/arm_compute/core/utils/misc/ShapeCalculator.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2022 Arm Limited. + * Copyright (c) 2017-2023 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -1537,39 +1537,32 @@ inline TensorShape compute_pool3d_shape(const TensorShape &src, Pooling3dLayerIn */ inline TensorShape compute_gather_shape(const TensorShape &input_shape, const TensorShape &indices_shape, uint32_t actual_axis) { - ARM_COMPUTE_ERROR_ON(input_shape.num_dimensions() > 4); - ARM_COMPUTE_ERROR_ON(actual_axis >= input_shape.num_dimensions()); - ARM_COMPUTE_ERROR_ON(indices_shape.num_dimensions() > 3); - TensorShape output_shape = input_shape; - if(indices_shape.num_dimensions() == 1u) + const auto input_num_dims = input_shape.num_dimensions(); + const auto indices_num_dims = indices_shape.num_dimensions(); + + ARM_COMPUTE_ERROR_ON(actual_axis >= input_num_dims); + ARM_COMPUTE_ERROR_ON(input_num_dims + indices_num_dims - 1 > Coordinates::num_max_dimensions); + + TensorShape output_shape; + size_t dim_no = 0; + + for(; dim_no < actual_axis; ++dim_no) { - output_shape[actual_axis] = indices_shape[0]; + output_shape.set(dim_no, input_shape[dim_no]); } - else + + for(; dim_no < actual_axis + indices_num_dims; ++dim_no) { - const auto ind_num_dims - { - indices_shape.num_dimensions() - }; - output_shape.shift_right(ind_num_dims - 1); - switch(actual_axis) - { - case 1: - { - output_shape[0] = input_shape[0]; - for(size_t idx = 0; idx < ind_num_dims; ++idx) - { - output_shape.set(actual_axis + idx, indices_shape[idx], false); - } - break; - } - default: - { - // 2d and 3d indices are only supported for axis == 1 - ARM_COMPUTE_ERROR_ON(actual_axis != 1 && indices_shape.num_dimensions() > 1); - } - } + output_shape.set(dim_no, indices_shape[dim_no - actual_axis]); + } + + for(; dim_no < input_num_dims + indices_num_dims - 1; ++dim_no) + { + output_shape.set(dim_no, input_shape[dim_no + 1 - indices_num_dims]); } + + ARM_COMPUTE_ERROR_ON(input_shape.total_size() * indices_shape.total_size() != output_shape.total_size() * input_shape[actual_axis]); + return output_shape; } } // namespace shape_calculator |