aboutsummaryrefslogtreecommitdiff
path: root/arm_compute/core/utils/misc/ShapeCalculator.h
diff options
context:
space:
mode:
Diffstat (limited to 'arm_compute/core/utils/misc/ShapeCalculator.h')
-rw-r--r--arm_compute/core/utils/misc/ShapeCalculator.h17
1 files changed, 13 insertions, 4 deletions
diff --git a/arm_compute/core/utils/misc/ShapeCalculator.h b/arm_compute/core/utils/misc/ShapeCalculator.h
index c4c360842f..080d63f60d 100644
--- a/arm_compute/core/utils/misc/ShapeCalculator.h
+++ b/arm_compute/core/utils/misc/ShapeCalculator.h
@@ -1179,15 +1179,24 @@ inline TensorShape compute_tiled_shape(const TensorShape &input_shape, const Mul
/** Calculate the reduced shape of a tensor given an axis
*
- * @param[in] input Input tensor info
- * @param[in] axis Axis on which to perform reduction
+ * @param[in] input Input tensor info
+ * @param[in] axis Axis on which to perform reduction
+ * @param[in] keep_dims (Optional) Whether to keep the dimension after reduction operation. Defaults to true.
*
* @return the calculated shape
*/
-inline TensorShape compute_reduced_shape(const TensorShape &input, unsigned int axis)
+inline TensorShape compute_reduced_shape(const TensorShape &input, unsigned int axis, bool keep_dims = true)
{
TensorShape output_shape{ input };
- output_shape.set(axis, 1);
+
+ if(!keep_dims)
+ {
+ output_shape.remove_dimension(axis);
+ }
+ else
+ {
+ output_shape.set(axis, 1);
+ }
return output_shape;
}