aboutsummaryrefslogtreecommitdiff
path: root/arm_compute/core/TensorShape.h
diff options
context:
space:
mode:
Diffstat (limited to 'arm_compute/core/TensorShape.h')
-rw-r--r--arm_compute/core/TensorShape.h41
1 files changed, 22 insertions, 19 deletions
diff --git a/arm_compute/core/TensorShape.h b/arm_compute/core/TensorShape.h
index b6ab9dc75a..c1707e262f 100644
--- a/arm_compute/core/TensorShape.h
+++ b/arm_compute/core/TensorShape.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2016-2021 Arm Limited.
+ * Copyright (c) 2016-2021, 2023 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -44,11 +44,10 @@ public:
* @param[in] dims Values to initialize the dimensions.
*/
template <typename... Ts>
- TensorShape(Ts... dims)
- : Dimensions{ dims... }
+ TensorShape(Ts... dims) : Dimensions{dims...}
{
// Initialize unspecified dimensions to 1
- if(_num_dimensions > 0)
+ if (_num_dimensions > 0)
{
std::fill(_id.begin() + _num_dimensions, _id.end(), 1);
}
@@ -79,7 +78,7 @@ public:
TensorShape &set(size_t dimension, size_t value, bool apply_dim_correction = true, bool increase_dim_unit = true)
{
// Clear entire shape if one dimension is zero
- if(value == 0)
+ if (value == 0)
{
_num_dimensions = 0;
std::fill(_id.begin(), _id.end(), 0);
@@ -94,7 +93,7 @@ public:
Dimensions::set(dimension, value, increase_dim_unit);
// Correct number dimensions to ignore trailing dimensions of size 1
- if(apply_dim_correction)
+ if (apply_dim_correction)
{
apply_dimension_correction();
}
@@ -106,9 +105,10 @@ public:
*
* @note The upper dimensions of the tensor shape will be shifted down by 1
*
- * @param[in] n Dimension to remove
+ * @param[in] n Dimension to remove
+ * @param[in] apply_dim_correction (Optional) Flag to state whether apply dimension correction (removing trailing dimensions with size of 1) after removing a dimension.
*/
- void remove_dimension(size_t n)
+ void remove_dimension(size_t n, bool apply_dim_correction = true)
{
ARM_COMPUTE_ERROR_ON(_num_dimensions < 1);
ARM_COMPUTE_ERROR_ON(n >= _num_dimensions);
@@ -122,7 +122,10 @@ public:
std::fill(_id.begin() + _num_dimensions, _id.end(), 1);
// Correct number dimensions to ignore trailing dimensions of size 1
- apply_dimension_correction();
+ if (apply_dim_correction)
+ {
+ apply_dimension_correction();
+ }
}
/** Collapse the first n dimensions.
@@ -208,26 +211,26 @@ public:
* @return The broadcasted shape or an empty shape if the shapes are not broadcast compatible.
*/
template <typename... Shapes>
- static TensorShape broadcast_shape(const Shapes &... shapes)
+ static TensorShape broadcast_shape(const Shapes &...shapes)
{
TensorShape bc_shape;
- auto broadcast = [&bc_shape](const TensorShape & other)
+ auto broadcast = [&bc_shape](const TensorShape &other)
{
- if(bc_shape.num_dimensions() == 0)
+ if (bc_shape.num_dimensions() == 0)
{
bc_shape = other;
}
- else if(other.num_dimensions() != 0)
+ else if (other.num_dimensions() != 0)
{
- for(size_t d = 0; d < TensorShape::num_max_dimensions; ++d)
+ for (size_t d = 0; d < TensorShape::num_max_dimensions; ++d)
{
const size_t dim_min = std::min(bc_shape[d], other[d]);
const size_t dim_max = std::max(bc_shape[d], other[d]);
- if((dim_min != 1) && (dim_min != dim_max))
+ if ((dim_min != 1) && (dim_min != dim_max))
{
- bc_shape = TensorShape{ 0U };
+ bc_shape = TensorShape{0U};
break;
}
@@ -245,9 +248,9 @@ private:
/** Remove trailing dimensions of size 1 from the reported number of dimensions. */
void apply_dimension_correction()
{
- for(int i = static_cast<int>(_num_dimensions) - 1; i > 0; --i)
+ for (int i = static_cast<int>(_num_dimensions) - 1; i > 0; --i)
{
- if(_id[i] == 1)
+ if (_id[i] == 1)
{
--_num_dimensions;
}
@@ -258,5 +261,5 @@ private:
}
}
};
-}
+} // namespace arm_compute
#endif /*ARM_COMPUTE_TENSORSHAPE_H*/