diff options
Diffstat (limited to 'arm_compute/core')
-rw-r--r-- | arm_compute/core/Helpers.h | 23 | ||||
-rw-r--r-- | arm_compute/core/Helpers.inl | 22 | ||||
-rw-r--r-- | arm_compute/core/utils/misc/ShapeCalculator.h | 53 |
3 files changed, 61 insertions, 37 deletions
diff --git a/arm_compute/core/Helpers.h b/arm_compute/core/Helpers.h index fd6e94c079..f19e1e12e0 100644 --- a/arm_compute/core/Helpers.h +++ b/arm_compute/core/Helpers.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2016-2021 Arm Limited. + * Copyright (c) 2016-2021, 2023 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -55,6 +55,16 @@ public: */ Iterator(const ITensor *tensor, const Window &window); + /** Create a container iterator for the tensor with the specified number of dimensions, stride, buffer pointer and window. + * + * @param[in] num_dims The number of dimensions. + * @param[in] strides The strides in bytes. + * @param[in] buffer The data buffer. + * @param[in] offset The offset in bytes from the beginning of the buffer to the first element of the tensor. + * @param[in] window The window which will be used to iterate over the tensor. + */ + Iterator(size_t num_dims, const Strides &strides, uint8_t *buffer, size_t offset, const Window &window); + /** Increment the iterator along the specified dimension of the step value associated to the dimension. * * @warning It is the caller's responsibility to call increment(dimension+1) when reaching the end of a dimension, the iterator will not check for overflow. @@ -86,6 +96,17 @@ public: void reset(size_t dimension); private: + + /** Initialize a container iterator for the tensor with the specified number of dimensions, stride, buffer pointer and window. + * + * @param[in] num_dims The number of dimensions. + * @param[in] strides The strides in bytes. + * @param[in] buffer The data buffer. + * @param[in] offset The offset in bytes from the beginning of the buffer to the first element of the tensor. + * @param[in] window The window which will be used to iterate over the tensor. + */ + void initialize(size_t num_dims, const Strides &strides, uint8_t *buffer, size_t offset, const Window &window); + uint8_t *_ptr; class Dimension diff --git a/arm_compute/core/Helpers.inl b/arm_compute/core/Helpers.inl index a910521f94..ff902bba20 100644 --- a/arm_compute/core/Helpers.inl +++ b/arm_compute/core/Helpers.inl @@ -1,5 +1,5 @@ /* - * Copyright (c) 2016-2021 Arm Limited. + * Copyright (c) 2016-2021, 2023 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -98,13 +98,23 @@ inline Iterator::Iterator(const ITensor *tensor, const Window &win) ARM_COMPUTE_ERROR_ON(tensor == nullptr); ARM_COMPUTE_ERROR_ON(tensor->info() == nullptr); - const ITensorInfo *info = tensor->info(); - const Strides &strides = info->strides_in_bytes(); + initialize(tensor->info()->num_dimensions(), tensor->info()->strides_in_bytes(), tensor->buffer(), tensor->info()->offset_first_element_in_bytes(), win); +} + +inline Iterator::Iterator(size_t num_dims, const Strides &strides, uint8_t *buffer, size_t offset, const Window &win) + : Iterator() +{ + initialize(num_dims, strides, buffer, offset, win); +} + +inline void Iterator::initialize(size_t num_dims, const Strides &strides, uint8_t *buffer, size_t offset, const Window &win) +{ + ARM_COMPUTE_ERROR_ON(buffer == nullptr); - _ptr = tensor->buffer() + info->offset_first_element_in_bytes(); + _ptr = buffer + offset; //Initialize the stride for each dimension and calculate the position of the first element of the iteration: - for(unsigned int n = 0; n < info->num_dimensions(); ++n) + for(unsigned int n = 0; n < num_dims; ++n) { _dims[n]._stride = win[n].step() * strides[n]; std::get<0>(_dims)._dim_start += static_cast<size_t>(strides[n]) * win[n].start(); @@ -116,7 +126,7 @@ inline Iterator::Iterator(const ITensor *tensor, const Window &win) _dims[n]._dim_start = std::get<0>(_dims)._dim_start; } - ARM_COMPUTE_ERROR_ON_WINDOW_DIMENSIONS_GTE(win, info->num_dimensions()); + ARM_COMPUTE_ERROR_ON_WINDOW_DIMENSIONS_GTE(win, num_dims); } inline void Iterator::increment(const size_t dimension) diff --git a/arm_compute/core/utils/misc/ShapeCalculator.h b/arm_compute/core/utils/misc/ShapeCalculator.h index 9e7c981814..94bd3aca03 100644 --- a/arm_compute/core/utils/misc/ShapeCalculator.h +++ b/arm_compute/core/utils/misc/ShapeCalculator.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2022 Arm Limited. + * Copyright (c) 2017-2023 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -1537,39 +1537,32 @@ inline TensorShape compute_pool3d_shape(const TensorShape &src, Pooling3dLayerIn */ inline TensorShape compute_gather_shape(const TensorShape &input_shape, const TensorShape &indices_shape, uint32_t actual_axis) { - ARM_COMPUTE_ERROR_ON(input_shape.num_dimensions() > 4); - ARM_COMPUTE_ERROR_ON(actual_axis >= input_shape.num_dimensions()); - ARM_COMPUTE_ERROR_ON(indices_shape.num_dimensions() > 3); - TensorShape output_shape = input_shape; - if(indices_shape.num_dimensions() == 1u) + const auto input_num_dims = input_shape.num_dimensions(); + const auto indices_num_dims = indices_shape.num_dimensions(); + + ARM_COMPUTE_ERROR_ON(actual_axis >= input_num_dims); + ARM_COMPUTE_ERROR_ON(input_num_dims + indices_num_dims - 1 > Coordinates::num_max_dimensions); + + TensorShape output_shape; + size_t dim_no = 0; + + for(; dim_no < actual_axis; ++dim_no) { - output_shape[actual_axis] = indices_shape[0]; + output_shape.set(dim_no, input_shape[dim_no]); } - else + + for(; dim_no < actual_axis + indices_num_dims; ++dim_no) { - const auto ind_num_dims - { - indices_shape.num_dimensions() - }; - output_shape.shift_right(ind_num_dims - 1); - switch(actual_axis) - { - case 1: - { - output_shape[0] = input_shape[0]; - for(size_t idx = 0; idx < ind_num_dims; ++idx) - { - output_shape.set(actual_axis + idx, indices_shape[idx], false); - } - break; - } - default: - { - // 2d and 3d indices are only supported for axis == 1 - ARM_COMPUTE_ERROR_ON(actual_axis != 1 && indices_shape.num_dimensions() > 1); - } - } + output_shape.set(dim_no, indices_shape[dim_no - actual_axis]); + } + + for(; dim_no < input_num_dims + indices_num_dims - 1; ++dim_no) + { + output_shape.set(dim_no, input_shape[dim_no + 1 - indices_num_dims]); } + + ARM_COMPUTE_ERROR_ON(input_shape.total_size() * indices_shape.total_size() != output_shape.total_size() * input_shape[actual_axis]); + return output_shape; } } // namespace shape_calculator |