From 8aa985e6cd553f4e2cee6cab74b82fa626896299 Mon Sep 17 00:00:00 2001 From: Gian Marco Iodice Date: Tue, 27 Nov 2018 15:58:08 +0000 Subject: COMPMID-1725: Implement Pack Change-Id: I13f6e4c600f39355f69e015409bf30dafdc5e3aa Reviewed-on: https://review.mlplatform.org/332 Tested-by: Arm Jenkins Reviewed-by: Michele Di Giorgio --- arm_compute/core/utils/misc/ShapeCalculator.h | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) (limited to 'arm_compute/core/utils') diff --git a/arm_compute/core/utils/misc/ShapeCalculator.h b/arm_compute/core/utils/misc/ShapeCalculator.h index 38906dfc9b..c625a07a7f 100644 --- a/arm_compute/core/utils/misc/ShapeCalculator.h +++ b/arm_compute/core/utils/misc/ShapeCalculator.h @@ -754,6 +754,28 @@ inline TensorShape calculate_width_concatenate_shape(const std::vector &inp return out_shape; } + +inline TensorShape compute_stack_shape(const ITensorInfo &a, unsigned int axis, unsigned int num_tensors) +{ + ARM_COMPUTE_ERROR_ON(axis > a.num_dimensions()); + ARM_COMPUTE_ERROR_ON(a.num_dimensions() > 4); + + TensorShape shape_out{ a.tensor_shape() }; + shape_out.set(axis, num_tensors); + + unsigned int i_shift = 0; + + for(unsigned int i = 0; i < a.num_dimensions(); ++i) + { + if(i == axis) + { + i_shift++; + } + + shape_out.set(i + i_shift, a.tensor_shape()[i]); + } + return shape_out; +} } // namespace shape_calculator } // namespace misc } // namespace arm_compute -- cgit v1.2.1