aboutsummaryrefslogtreecommitdiff
path: root/arm_compute/core/TensorShape.h
diff options
context:
space:
mode:
Diffstat (limited to 'arm_compute/core/TensorShape.h')
-rw-r--r--arm_compute/core/TensorShape.h50
1 files changed, 27 insertions, 23 deletions
diff --git a/arm_compute/core/TensorShape.h b/arm_compute/core/TensorShape.h
index 57d8f6cf63..c1707e262f 100644
--- a/arm_compute/core/TensorShape.h
+++ b/arm_compute/core/TensorShape.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2016-2019 ARM Limited.
+ * Copyright (c) 2016-2021, 2023 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -36,7 +36,7 @@
namespace arm_compute
{
/** Shape of a tensor */
-class TensorShape : public Dimensions<uint32_t>
+class TensorShape : public Dimensions<size_t>
{
public:
/** Constructor to initialize the tensor shape.
@@ -44,11 +44,10 @@ public:
* @param[in] dims Values to initialize the dimensions.
*/
template <typename... Ts>
- TensorShape(Ts... dims)
- : Dimensions{ dims... }
+ TensorShape(Ts... dims) : Dimensions{dims...}
{
// Initialize unspecified dimensions to 1
- if(_num_dimensions > 0)
+ if (_num_dimensions > 0)
{
std::fill(_id.begin() + _num_dimensions, _id.end(), 1);
}
@@ -71,14 +70,15 @@ public:
*
* @param[in] dimension Dimension for which the value is set.
* @param[in] value Value to be set for the dimension.
- * @param[in] apply_dim_correction Flag to state whether apply dimension correction after setting one dimension. E.g. when permuting NCHW -> NHWC, 1x1x2 would become 2x1x1, but _num_dimensions should be 3 rather than 1.
+ * @param[in] apply_dim_correction (Optional) Flag to state whether apply dimension correction after setting one dimension. E.g. when permuting NCHW -> NHWC, 1x1x2 would become 2x1x1, but _num_dimensions should be 3 rather than 1.
+ * @param[in] increase_dim_unit (Optional) Set to true if new unit dimensions increase the number of dimensions of the shape.
*
* @return *this.
*/
- TensorShape &set(size_t dimension, size_t value, bool apply_dim_correction = true)
+ TensorShape &set(size_t dimension, size_t value, bool apply_dim_correction = true, bool increase_dim_unit = true)
{
// Clear entire shape if one dimension is zero
- if(value == 0)
+ if (value == 0)
{
_num_dimensions = 0;
std::fill(_id.begin(), _id.end(), 0);
@@ -90,10 +90,10 @@ public:
// Set the specified dimension and increase the number of dimensions if
// necessary
- Dimensions::set(dimension, value);
+ Dimensions::set(dimension, value, increase_dim_unit);
// Correct number dimensions to ignore trailing dimensions of size 1
- if(apply_dim_correction)
+ if (apply_dim_correction)
{
apply_dimension_correction();
}
@@ -105,9 +105,10 @@ public:
*
* @note The upper dimensions of the tensor shape will be shifted down by 1
*
- * @param[in] n Dimension to remove
+ * @param[in] n Dimension to remove
+ * @param[in] apply_dim_correction (Optional) Flag to state whether apply dimension correction (removing trailing dimensions with size of 1) after removing a dimension.
*/
- void remove_dimension(size_t n)
+ void remove_dimension(size_t n, bool apply_dim_correction = true)
{
ARM_COMPUTE_ERROR_ON(_num_dimensions < 1);
ARM_COMPUTE_ERROR_ON(n >= _num_dimensions);
@@ -121,7 +122,10 @@ public:
std::fill(_id.begin() + _num_dimensions, _id.end(), 1);
// Correct number dimensions to ignore trailing dimensions of size 1
- apply_dimension_correction();
+ if (apply_dim_correction)
+ {
+ apply_dimension_correction();
+ }
}
/** Collapse the first n dimensions.
@@ -207,26 +211,26 @@ public:
* @return The broadcasted shape or an empty shape if the shapes are not broadcast compatible.
*/
template <typename... Shapes>
- static TensorShape broadcast_shape(const Shapes &... shapes)
+ static TensorShape broadcast_shape(const Shapes &...shapes)
{
TensorShape bc_shape;
- auto broadcast = [&bc_shape](const TensorShape & other)
+ auto broadcast = [&bc_shape](const TensorShape &other)
{
- if(bc_shape.num_dimensions() == 0)
+ if (bc_shape.num_dimensions() == 0)
{
bc_shape = other;
}
- else if(other.num_dimensions() != 0)
+ else if (other.num_dimensions() != 0)
{
- for(size_t d = 0; d < TensorShape::num_max_dimensions; ++d)
+ for (size_t d = 0; d < TensorShape::num_max_dimensions; ++d)
{
const size_t dim_min = std::min(bc_shape[d], other[d]);
const size_t dim_max = std::max(bc_shape[d], other[d]);
- if((dim_min != 1) && (dim_min != dim_max))
+ if ((dim_min != 1) && (dim_min != dim_max))
{
- bc_shape = TensorShape{ 0U };
+ bc_shape = TensorShape{0U};
break;
}
@@ -244,9 +248,9 @@ private:
/** Remove trailing dimensions of size 1 from the reported number of dimensions. */
void apply_dimension_correction()
{
- for(int i = static_cast<int>(_num_dimensions) - 1; i > 0; --i)
+ for (int i = static_cast<int>(_num_dimensions) - 1; i > 0; --i)
{
- if(_id[i] == 1)
+ if (_id[i] == 1)
{
--_num_dimensions;
}
@@ -257,5 +261,5 @@ private:
}
}
};
-}
+} // namespace arm_compute
#endif /*ARM_COMPUTE_TENSORSHAPE_H*/