diff options
Diffstat (limited to 'arm_compute/core/Dimensions.h')
-rw-r--r-- | arm_compute/core/Dimensions.h | 44 |
1 files changed, 32 insertions, 12 deletions
diff --git a/arm_compute/core/Dimensions.h b/arm_compute/core/Dimensions.h index fbaef3a8f0..bb8692d70a 100644 --- a/arm_compute/core/Dimensions.h +++ b/arm_compute/core/Dimensions.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017-2019 ARM Limited. + * Copyright (c) 2017-2021 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -29,6 +29,7 @@ #include <algorithm> #include <array> #include <functional> +#include <limits> #include <numeric> namespace arm_compute @@ -49,8 +50,7 @@ public: * @param[in] dims Values to initialize the dimensions. */ template <typename... Ts> - explicit Dimensions(Ts... dims) - : _id{ { static_cast<T>(dims)... } }, _num_dimensions{ sizeof...(dims) } + explicit Dimensions(Ts... dims) : _id{{static_cast<T>(dims)...}}, _num_dimensions{sizeof...(dims)} { } @@ -68,14 +68,19 @@ public: /** Accessor to set the value of one of the dimensions. * - * @param[in] dimension Dimension for which the value is set. - * @param[in] value Value to be set for the dimension. + * @param[in] dimension Dimension for which the value is set. + * @param[in] value Value to be set for the dimension. + * @param[in] increase_dim_unit (Optional) Set to true if new unit dimensions increase the number of dimensions (e.g. for Coordinates), false otherwise (e.g. for TensorShapes) */ - void set(size_t dimension, T value) + void set(size_t dimension, T value, bool increase_dim_unit = true) { ARM_COMPUTE_ERROR_ON(dimension >= num_max_dimensions); - _id[dimension] = value; - _num_dimensions = std::max(_num_dimensions, dimension + 1); + _id[dimension] = value; + // Don't increase the number of dimensions if the new dimension is 1 + if (increase_dim_unit || value != 1) + { + _num_dimensions = std::max(_num_dimensions, dimension + 1); + } } /** Alias to access the size of the first dimension */ T x() const @@ -92,6 +97,21 @@ public: { return _id[2]; } + /** Increments the given dimension by a step size, avoiding overflows + * + * @note Precondition: dim < _num_dimensions + * + * @param[in] dim Dimension to increment. + * @param[in] step Step to increment @p dim by. + */ + void increment(size_t dim, T step = 1) + { + ARM_COMPUTE_ERROR_ON(dim >= _num_dimensions); + if ((std::numeric_limits<T>::max() - _id[dim]) >= step) + { + _id[dim] += step; + } + } /** Generic accessor to get the size of any dimension * * @note Precondition: dimension < Dimensions::num_max_dimensions @@ -141,7 +161,7 @@ public: const size_t last = std::min(_num_dimensions, first + n); - if(last > (first + 1)) + if (last > (first + 1)) { // Collapse dimensions into the first _id[first] = std::accumulate(&_id[first], &_id[last], 1, std::multiplies<T>()); @@ -175,7 +195,7 @@ public: void remove(size_t idx) { ARM_COMPUTE_ERROR_ON(_num_dimensions < 1); - if(idx >= _num_dimensions) + if (idx >= _num_dimensions) { return; } @@ -241,7 +261,7 @@ protected: ~Dimensions() = default; std::array<T, num_max_dimensions> _id; - size_t _num_dimensions{ 0 }; + size_t _num_dimensions{0}; }; /** Check that given dimensions are equal. @@ -268,5 +288,5 @@ inline bool operator!=(const Dimensions<T> &lhs, const Dimensions<T> &rhs) { return !(lhs == rhs); } -} +} // namespace arm_compute #endif /*ARM_COMPUTE_DIMENSIONS_H*/ |