aboutsummaryrefslogtreecommitdiff
path: root/arm_compute/core/TensorShape.h
diff options
context:
space:
mode:
authorGiorgio Arena <giorgio.arena@arm.com>2020-12-11 13:39:02 +0000
committerGiorgio Arena <giorgio.arena@arm.com>2020-12-11 14:53:45 +0000
commitec241b48ea7481e797285788fd68e5e1d42382bb (patch)
tree9fa8354cc5ec6f018ab04adf8ed68612cfd0e9a4 /arm_compute/core/TensorShape.h
parentc53266e45f3c8c07dff88c61e5bfa01c6d3ba3f0 (diff)
downloadComputeLibrary-ec241b48ea7481e797285788fd68e5e1d42382bb.tar.gz
Allow TensorShape to increment num_dimensions for new unit dimensions, without correction
Resolves: COMPMID-4053 Change-Id: Ie0b58b393e07518deb2c1fe4f82cbf0ce257f39a Signed-off-by: Giorgio Arena <giorgio.arena@arm.com> Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/4691 Tested-by: Arm Jenkins <bsgcomp@arm.com> Reviewed-by: SiCong Li <sicong.li@arm.com> Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'arm_compute/core/TensorShape.h')
-rw-r--r--arm_compute/core/TensorShape.h7
1 files changed, 4 insertions, 3 deletions
diff --git a/arm_compute/core/TensorShape.h b/arm_compute/core/TensorShape.h
index fe3921f766..7c5ea8d1b7 100644
--- a/arm_compute/core/TensorShape.h
+++ b/arm_compute/core/TensorShape.h
@@ -71,11 +71,12 @@ public:
*
* @param[in] dimension Dimension for which the value is set.
* @param[in] value Value to be set for the dimension.
- * @param[in] apply_dim_correction Flag to state whether apply dimension correction after setting one dimension. E.g. when permuting NCHW -> NHWC, 1x1x2 would become 2x1x1, but _num_dimensions should be 3 rather than 1.
+ * @param[in] apply_dim_correction (Optional) Flag to state whether apply dimension correction after setting one dimension. E.g. when permuting NCHW -> NHWC, 1x1x2 would become 2x1x1, but _num_dimensions should be 3 rather than 1.
+ * @param[in] increase_dim_unit (Optional) Set to true if new unit dimensions increase the number of dimensions of the shape.
*
* @return *this.
*/
- TensorShape &set(size_t dimension, size_t value, bool apply_dim_correction = true)
+ TensorShape &set(size_t dimension, size_t value, bool apply_dim_correction = true, bool increase_dim_unit = true)
{
// Clear entire shape if one dimension is zero
if(value == 0)
@@ -90,7 +91,7 @@ public:
// Set the specified dimension and increase the number of dimensions if
// necessary
- Dimensions::set(dimension, value, false);
+ Dimensions::set(dimension, value, increase_dim_unit);
// Correct number dimensions to ignore trailing dimensions of size 1
if(apply_dim_correction)