aboutsummaryrefslogtreecommitdiff
path: root/arm_compute/core
diff options
context:
space:
mode:
Diffstat (limited to 'arm_compute/core')
-rw-r--r--arm_compute/core/Helpers.h23
-rw-r--r--arm_compute/core/Helpers.inl22
-rw-r--r--arm_compute/core/utils/misc/ShapeCalculator.h53
3 files changed, 61 insertions, 37 deletions
diff --git a/arm_compute/core/Helpers.h b/arm_compute/core/Helpers.h
index fd6e94c079..f19e1e12e0 100644
--- a/arm_compute/core/Helpers.h
+++ b/arm_compute/core/Helpers.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2016-2021 Arm Limited.
+ * Copyright (c) 2016-2021, 2023 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -55,6 +55,16 @@ public:
*/
Iterator(const ITensor *tensor, const Window &window);
+ /** Create a container iterator for the tensor with the specified number of dimensions, stride, buffer pointer and window.
+ *
+ * @param[in] num_dims The number of dimensions.
+ * @param[in] strides The strides in bytes.
+ * @param[in] buffer The data buffer.
+ * @param[in] offset The offset in bytes from the beginning of the buffer to the first element of the tensor.
+ * @param[in] window The window which will be used to iterate over the tensor.
+ */
+ Iterator(size_t num_dims, const Strides &strides, uint8_t *buffer, size_t offset, const Window &window);
+
/** Increment the iterator along the specified dimension of the step value associated to the dimension.
*
* @warning It is the caller's responsibility to call increment(dimension+1) when reaching the end of a dimension, the iterator will not check for overflow.
@@ -86,6 +96,17 @@ public:
void reset(size_t dimension);
private:
+
+ /** Initialize a container iterator for the tensor with the specified number of dimensions, stride, buffer pointer and window.
+ *
+ * @param[in] num_dims The number of dimensions.
+ * @param[in] strides The strides in bytes.
+ * @param[in] buffer The data buffer.
+ * @param[in] offset The offset in bytes from the beginning of the buffer to the first element of the tensor.
+ * @param[in] window The window which will be used to iterate over the tensor.
+ */
+ void initialize(size_t num_dims, const Strides &strides, uint8_t *buffer, size_t offset, const Window &window);
+
uint8_t *_ptr;
class Dimension
diff --git a/arm_compute/core/Helpers.inl b/arm_compute/core/Helpers.inl
index a910521f94..ff902bba20 100644
--- a/arm_compute/core/Helpers.inl
+++ b/arm_compute/core/Helpers.inl
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2016-2021 Arm Limited.
+ * Copyright (c) 2016-2021, 2023 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -98,13 +98,23 @@ inline Iterator::Iterator(const ITensor *tensor, const Window &win)
ARM_COMPUTE_ERROR_ON(tensor == nullptr);
ARM_COMPUTE_ERROR_ON(tensor->info() == nullptr);
- const ITensorInfo *info = tensor->info();
- const Strides &strides = info->strides_in_bytes();
+ initialize(tensor->info()->num_dimensions(), tensor->info()->strides_in_bytes(), tensor->buffer(), tensor->info()->offset_first_element_in_bytes(), win);
+}
+
+inline Iterator::Iterator(size_t num_dims, const Strides &strides, uint8_t *buffer, size_t offset, const Window &win)
+ : Iterator()
+{
+ initialize(num_dims, strides, buffer, offset, win);
+}
+
+inline void Iterator::initialize(size_t num_dims, const Strides &strides, uint8_t *buffer, size_t offset, const Window &win)
+{
+ ARM_COMPUTE_ERROR_ON(buffer == nullptr);
- _ptr = tensor->buffer() + info->offset_first_element_in_bytes();
+ _ptr = buffer + offset;
//Initialize the stride for each dimension and calculate the position of the first element of the iteration:
- for(unsigned int n = 0; n < info->num_dimensions(); ++n)
+ for(unsigned int n = 0; n < num_dims; ++n)
{
_dims[n]._stride = win[n].step() * strides[n];
std::get<0>(_dims)._dim_start += static_cast<size_t>(strides[n]) * win[n].start();
@@ -116,7 +126,7 @@ inline Iterator::Iterator(const ITensor *tensor, const Window &win)
_dims[n]._dim_start = std::get<0>(_dims)._dim_start;
}
- ARM_COMPUTE_ERROR_ON_WINDOW_DIMENSIONS_GTE(win, info->num_dimensions());
+ ARM_COMPUTE_ERROR_ON_WINDOW_DIMENSIONS_GTE(win, num_dims);
}
inline void Iterator::increment(const size_t dimension)
diff --git a/arm_compute/core/utils/misc/ShapeCalculator.h b/arm_compute/core/utils/misc/ShapeCalculator.h
index 9e7c981814..94bd3aca03 100644
--- a/arm_compute/core/utils/misc/ShapeCalculator.h
+++ b/arm_compute/core/utils/misc/ShapeCalculator.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2022 Arm Limited.
+ * Copyright (c) 2017-2023 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -1537,39 +1537,32 @@ inline TensorShape compute_pool3d_shape(const TensorShape &src, Pooling3dLayerIn
*/
inline TensorShape compute_gather_shape(const TensorShape &input_shape, const TensorShape &indices_shape, uint32_t actual_axis)
{
- ARM_COMPUTE_ERROR_ON(input_shape.num_dimensions() > 4);
- ARM_COMPUTE_ERROR_ON(actual_axis >= input_shape.num_dimensions());
- ARM_COMPUTE_ERROR_ON(indices_shape.num_dimensions() > 3);
- TensorShape output_shape = input_shape;
- if(indices_shape.num_dimensions() == 1u)
+ const auto input_num_dims = input_shape.num_dimensions();
+ const auto indices_num_dims = indices_shape.num_dimensions();
+
+ ARM_COMPUTE_ERROR_ON(actual_axis >= input_num_dims);
+ ARM_COMPUTE_ERROR_ON(input_num_dims + indices_num_dims - 1 > Coordinates::num_max_dimensions);
+
+ TensorShape output_shape;
+ size_t dim_no = 0;
+
+ for(; dim_no < actual_axis; ++dim_no)
{
- output_shape[actual_axis] = indices_shape[0];
+ output_shape.set(dim_no, input_shape[dim_no]);
}
- else
+
+ for(; dim_no < actual_axis + indices_num_dims; ++dim_no)
{
- const auto ind_num_dims
- {
- indices_shape.num_dimensions()
- };
- output_shape.shift_right(ind_num_dims - 1);
- switch(actual_axis)
- {
- case 1:
- {
- output_shape[0] = input_shape[0];
- for(size_t idx = 0; idx < ind_num_dims; ++idx)
- {
- output_shape.set(actual_axis + idx, indices_shape[idx], false);
- }
- break;
- }
- default:
- {
- // 2d and 3d indices are only supported for axis == 1
- ARM_COMPUTE_ERROR_ON(actual_axis != 1 && indices_shape.num_dimensions() > 1);
- }
- }
+ output_shape.set(dim_no, indices_shape[dim_no - actual_axis]);
+ }
+
+ for(; dim_no < input_num_dims + indices_num_dims - 1; ++dim_no)
+ {
+ output_shape.set(dim_no, input_shape[dim_no + 1 - indices_num_dims]);
}
+
+ ARM_COMPUTE_ERROR_ON(input_shape.total_size() * indices_shape.total_size() != output_shape.total_size() * input_shape[actual_axis]);
+
return output_shape;
}
} // namespace shape_calculator