From ec241b48ea7481e797285788fd68e5e1d42382bb Mon Sep 17 00:00:00 2001 From: Giorgio Arena Date: Fri, 11 Dec 2020 13:39:02 +0000 Subject: Allow TensorShape to increment num_dimensions for new unit dimensions, without correction Resolves: COMPMID-4053 Change-Id: Ie0b58b393e07518deb2c1fe4f82cbf0ce257f39a Signed-off-by: Giorgio Arena Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/4691 Tested-by: Arm Jenkins Reviewed-by: SiCong Li Comments-Addressed: Arm Jenkins --- arm_compute/core/Dimensions.h | 2 +- arm_compute/core/Helpers.h | 2 +- arm_compute/core/TensorShape.h | 7 ++++--- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/arm_compute/core/Dimensions.h b/arm_compute/core/Dimensions.h index 0e6e1f6681..d487b997a3 100644 --- a/arm_compute/core/Dimensions.h +++ b/arm_compute/core/Dimensions.h @@ -70,7 +70,7 @@ public: * * @param[in] dimension Dimension for which the value is set. * @param[in] value Value to be set for the dimension. - * @param[in] increase_dim_unit (Optional) Set to true if unit dimension increase the number of dimensions (e.g. for Coordinates), false otherwise (e.g. for TensorShapes) + * @param[in] increase_dim_unit (Optional) Set to true if new unit dimensions increase the number of dimensions (e.g. for Coordinates), false otherwise (e.g. for TensorShapes) */ void set(size_t dimension, T value, bool increase_dim_unit = true) { diff --git a/arm_compute/core/Helpers.h b/arm_compute/core/Helpers.h index 5a8d6efe9d..b6635aba6d 100644 --- a/arm_compute/core/Helpers.h +++ b/arm_compute/core/Helpers.h @@ -145,7 +145,7 @@ inline void permute(TensorShape &shape, const PermutationVector &perm) for(unsigned int i = 0; i < perm.num_dimensions(); ++i) { size_t dimension_val = (perm[i] < shape.num_dimensions()) ? shape_copy[perm[i]] : 1; - shape.set(i, dimension_val, false); // Avoid changes in _num_dimension + shape.set(i, dimension_val, false, false); // Avoid changes in _num_dimension } } diff --git a/arm_compute/core/TensorShape.h b/arm_compute/core/TensorShape.h index fe3921f766..7c5ea8d1b7 100644 --- a/arm_compute/core/TensorShape.h +++ b/arm_compute/core/TensorShape.h @@ -71,11 +71,12 @@ public: * * @param[in] dimension Dimension for which the value is set. * @param[in] value Value to be set for the dimension. - * @param[in] apply_dim_correction Flag to state whether apply dimension correction after setting one dimension. E.g. when permuting NCHW -> NHWC, 1x1x2 would become 2x1x1, but _num_dimensions should be 3 rather than 1. + * @param[in] apply_dim_correction (Optional) Flag to state whether apply dimension correction after setting one dimension. E.g. when permuting NCHW -> NHWC, 1x1x2 would become 2x1x1, but _num_dimensions should be 3 rather than 1. + * @param[in] increase_dim_unit (Optional) Set to true if new unit dimensions increase the number of dimensions of the shape. * * @return *this. */ - TensorShape &set(size_t dimension, size_t value, bool apply_dim_correction = true) + TensorShape &set(size_t dimension, size_t value, bool apply_dim_correction = true, bool increase_dim_unit = true) { // Clear entire shape if one dimension is zero if(value == 0) @@ -90,7 +91,7 @@ public: // Set the specified dimension and increase the number of dimensions if // necessary - Dimensions::set(dimension, value, false); + Dimensions::set(dimension, value, increase_dim_unit); // Correct number dimensions to ignore trailing dimensions of size 1 if(apply_dim_correction) -- cgit v1.2.1