aboutsummaryrefslogtreecommitdiff
path: root/arm_compute/core/utils/misc/ShapeCalculator.h
diff options
context:
space:
mode:
Diffstat (limited to 'arm_compute/core/utils/misc/ShapeCalculator.h')
-rw-r--r--arm_compute/core/utils/misc/ShapeCalculator.h10
1 files changed, 9 insertions, 1 deletions
diff --git a/arm_compute/core/utils/misc/ShapeCalculator.h b/arm_compute/core/utils/misc/ShapeCalculator.h
index 11d20c919f..56f65d0ba8 100644
--- a/arm_compute/core/utils/misc/ShapeCalculator.h
+++ b/arm_compute/core/utils/misc/ShapeCalculator.h
@@ -113,7 +113,15 @@ inline TensorShape compute_interleaved_shape(const ITensorInfo &a, int mult_inte
const int M = a.dimension(1) * a.dimension(2);
const int height = std::ceil(M / static_cast<float>(interleave_width));
shape_interleaved_a.set(1, height);
- shape_interleaved_a.remove_dimension(2);
+
+ // When the data format is NHWC and the shapes are Nx1x1
+ // the tensor shape num_dimensions is automatically set to 1 instead of 3.
+ // To avoid failures by removing a dimension that doesn't exist
+ // check if the number of dimensions is greater than 2.
+ if(shape_interleaved_a.num_dimensions() > 2)
+ {
+ shape_interleaved_a.remove_dimension(2);
+ }
}
else
{