aboutsummaryrefslogtreecommitdiff
path: root/arm_compute/core/Dimensions.h
diff options
context:
space:
mode:
Diffstat (limited to 'arm_compute/core/Dimensions.h')
-rw-r--r--arm_compute/core/Dimensions.h44
1 files changed, 32 insertions, 12 deletions
diff --git a/arm_compute/core/Dimensions.h b/arm_compute/core/Dimensions.h
index fbaef3a8f0..bb8692d70a 100644
--- a/arm_compute/core/Dimensions.h
+++ b/arm_compute/core/Dimensions.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2019 ARM Limited.
+ * Copyright (c) 2017-2021 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -29,6 +29,7 @@
#include <algorithm>
#include <array>
#include <functional>
+#include <limits>
#include <numeric>
namespace arm_compute
@@ -49,8 +50,7 @@ public:
* @param[in] dims Values to initialize the dimensions.
*/
template <typename... Ts>
- explicit Dimensions(Ts... dims)
- : _id{ { static_cast<T>(dims)... } }, _num_dimensions{ sizeof...(dims) }
+ explicit Dimensions(Ts... dims) : _id{{static_cast<T>(dims)...}}, _num_dimensions{sizeof...(dims)}
{
}
@@ -68,14 +68,19 @@ public:
/** Accessor to set the value of one of the dimensions.
*
- * @param[in] dimension Dimension for which the value is set.
- * @param[in] value Value to be set for the dimension.
+ * @param[in] dimension Dimension for which the value is set.
+ * @param[in] value Value to be set for the dimension.
+ * @param[in] increase_dim_unit (Optional) Set to true if new unit dimensions increase the number of dimensions (e.g. for Coordinates), false otherwise (e.g. for TensorShapes)
*/
- void set(size_t dimension, T value)
+ void set(size_t dimension, T value, bool increase_dim_unit = true)
{
ARM_COMPUTE_ERROR_ON(dimension >= num_max_dimensions);
- _id[dimension] = value;
- _num_dimensions = std::max(_num_dimensions, dimension + 1);
+ _id[dimension] = value;
+ // Don't increase the number of dimensions if the new dimension is 1
+ if (increase_dim_unit || value != 1)
+ {
+ _num_dimensions = std::max(_num_dimensions, dimension + 1);
+ }
}
/** Alias to access the size of the first dimension */
T x() const
@@ -92,6 +97,21 @@ public:
{
return _id[2];
}
+ /** Increments the given dimension by a step size, avoiding overflows
+ *
+ * @note Precondition: dim < _num_dimensions
+ *
+ * @param[in] dim Dimension to increment.
+ * @param[in] step Step to increment @p dim by.
+ */
+ void increment(size_t dim, T step = 1)
+ {
+ ARM_COMPUTE_ERROR_ON(dim >= _num_dimensions);
+ if ((std::numeric_limits<T>::max() - _id[dim]) >= step)
+ {
+ _id[dim] += step;
+ }
+ }
/** Generic accessor to get the size of any dimension
*
* @note Precondition: dimension < Dimensions::num_max_dimensions
@@ -141,7 +161,7 @@ public:
const size_t last = std::min(_num_dimensions, first + n);
- if(last > (first + 1))
+ if (last > (first + 1))
{
// Collapse dimensions into the first
_id[first] = std::accumulate(&_id[first], &_id[last], 1, std::multiplies<T>());
@@ -175,7 +195,7 @@ public:
void remove(size_t idx)
{
ARM_COMPUTE_ERROR_ON(_num_dimensions < 1);
- if(idx >= _num_dimensions)
+ if (idx >= _num_dimensions)
{
return;
}
@@ -241,7 +261,7 @@ protected:
~Dimensions() = default;
std::array<T, num_max_dimensions> _id;
- size_t _num_dimensions{ 0 };
+ size_t _num_dimensions{0};
};
/** Check that given dimensions are equal.
@@ -268,5 +288,5 @@ inline bool operator!=(const Dimensions<T> &lhs, const Dimensions<T> &rhs)
{
return !(lhs == rhs);
}
-}
+} // namespace arm_compute
#endif /*ARM_COMPUTE_DIMENSIONS_H*/