diff options
Diffstat (limited to 'arm_compute/core/TensorShape.h')
-rw-r--r-- | arm_compute/core/TensorShape.h | 41 |
1 files changed, 22 insertions, 19 deletions
diff --git a/arm_compute/core/TensorShape.h b/arm_compute/core/TensorShape.h index 7c5ea8d1b7..c1707e262f 100644 --- a/arm_compute/core/TensorShape.h +++ b/arm_compute/core/TensorShape.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2016-2020 Arm Limited. + * Copyright (c) 2016-2021, 2023 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -44,11 +44,10 @@ public: * @param[in] dims Values to initialize the dimensions. */ template <typename... Ts> - TensorShape(Ts... dims) - : Dimensions{ dims... } + TensorShape(Ts... dims) : Dimensions{dims...} { // Initialize unspecified dimensions to 1 - if(_num_dimensions > 0) + if (_num_dimensions > 0) { std::fill(_id.begin() + _num_dimensions, _id.end(), 1); } @@ -79,7 +78,7 @@ public: TensorShape &set(size_t dimension, size_t value, bool apply_dim_correction = true, bool increase_dim_unit = true) { // Clear entire shape if one dimension is zero - if(value == 0) + if (value == 0) { _num_dimensions = 0; std::fill(_id.begin(), _id.end(), 0); @@ -94,7 +93,7 @@ public: Dimensions::set(dimension, value, increase_dim_unit); // Correct number dimensions to ignore trailing dimensions of size 1 - if(apply_dim_correction) + if (apply_dim_correction) { apply_dimension_correction(); } @@ -106,9 +105,10 @@ public: * * @note The upper dimensions of the tensor shape will be shifted down by 1 * - * @param[in] n Dimension to remove + * @param[in] n Dimension to remove + * @param[in] apply_dim_correction (Optional) Flag to state whether apply dimension correction (removing trailing dimensions with size of 1) after removing a dimension. */ - void remove_dimension(size_t n) + void remove_dimension(size_t n, bool apply_dim_correction = true) { ARM_COMPUTE_ERROR_ON(_num_dimensions < 1); ARM_COMPUTE_ERROR_ON(n >= _num_dimensions); @@ -122,7 +122,10 @@ public: std::fill(_id.begin() + _num_dimensions, _id.end(), 1); // Correct number dimensions to ignore trailing dimensions of size 1 - apply_dimension_correction(); + if (apply_dim_correction) + { + apply_dimension_correction(); + } } /** Collapse the first n dimensions. @@ -208,26 +211,26 @@ public: * @return The broadcasted shape or an empty shape if the shapes are not broadcast compatible. */ template <typename... Shapes> - static TensorShape broadcast_shape(const Shapes &... shapes) + static TensorShape broadcast_shape(const Shapes &...shapes) { TensorShape bc_shape; - auto broadcast = [&bc_shape](const TensorShape & other) + auto broadcast = [&bc_shape](const TensorShape &other) { - if(bc_shape.num_dimensions() == 0) + if (bc_shape.num_dimensions() == 0) { bc_shape = other; } - else if(other.num_dimensions() != 0) + else if (other.num_dimensions() != 0) { - for(size_t d = 0; d < TensorShape::num_max_dimensions; ++d) + for (size_t d = 0; d < TensorShape::num_max_dimensions; ++d) { const size_t dim_min = std::min(bc_shape[d], other[d]); const size_t dim_max = std::max(bc_shape[d], other[d]); - if((dim_min != 1) && (dim_min != dim_max)) + if ((dim_min != 1) && (dim_min != dim_max)) { - bc_shape = TensorShape{ 0U }; + bc_shape = TensorShape{0U}; break; } @@ -245,9 +248,9 @@ private: /** Remove trailing dimensions of size 1 from the reported number of dimensions. */ void apply_dimension_correction() { - for(int i = static_cast<int>(_num_dimensions) - 1; i > 0; --i) + for (int i = static_cast<int>(_num_dimensions) - 1; i > 0; --i) { - if(_id[i] == 1) + if (_id[i] == 1) { --_num_dimensions; } @@ -258,5 +261,5 @@ private: } } }; -} +} // namespace arm_compute #endif /*ARM_COMPUTE_TENSORSHAPE_H*/ |