diff options
Diffstat (limited to 'arm_compute/core/TensorShape.h')
-rw-r--r-- | arm_compute/core/TensorShape.h | 12 |
1 files changed, 8 insertions, 4 deletions
diff --git a/arm_compute/core/TensorShape.h b/arm_compute/core/TensorShape.h index d5532e8a6a..0c3d9414e1 100644 --- a/arm_compute/core/TensorShape.h +++ b/arm_compute/core/TensorShape.h @@ -69,12 +69,13 @@ public: /** Accessor to set the value of one of the dimensions. * - * @param[in] dimension Dimension for which the value is set. - * @param[in] value Value to be set for the dimension. + * @param[in] dimension Dimension for which the value is set. + * @param[in] value Value to be set for the dimension. + * @param[in] apply_dim_correction Flag to state whether apply dimension correction after setting one dimension. E.g. when permuting NCHW -> NHWC, 1x1x2 would become 2x1x1, but _num_dimensions should be 3 rather than 1. * * @return *this. */ - TensorShape &set(size_t dimension, size_t value) + TensorShape &set(size_t dimension, size_t value, bool apply_dim_correction = true) { // Clear entire shape if one dimension is zero if(value == 0) @@ -92,7 +93,10 @@ public: Dimensions::set(dimension, value); // Correct number dimensions to ignore trailing dimensions of size 1 - apply_dimension_correction(); + if(apply_dim_correction) + { + apply_dimension_correction(); + } } return *this; } |