diff options
Diffstat (limited to 'arm_compute/core/utils/misc/ShapeCalculator.h')
-rw-r--r-- | arm_compute/core/utils/misc/ShapeCalculator.h | 22 |
1 files changed, 22 insertions, 0 deletions
diff --git a/arm_compute/core/utils/misc/ShapeCalculator.h b/arm_compute/core/utils/misc/ShapeCalculator.h index 38906dfc9b..c625a07a7f 100644 --- a/arm_compute/core/utils/misc/ShapeCalculator.h +++ b/arm_compute/core/utils/misc/ShapeCalculator.h @@ -754,6 +754,28 @@ inline TensorShape calculate_width_concatenate_shape(const std::vector<T *> &inp return out_shape; } + +inline TensorShape compute_stack_shape(const ITensorInfo &a, unsigned int axis, unsigned int num_tensors) +{ + ARM_COMPUTE_ERROR_ON(axis > a.num_dimensions()); + ARM_COMPUTE_ERROR_ON(a.num_dimensions() > 4); + + TensorShape shape_out{ a.tensor_shape() }; + shape_out.set(axis, num_tensors); + + unsigned int i_shift = 0; + + for(unsigned int i = 0; i < a.num_dimensions(); ++i) + { + if(i == axis) + { + i_shift++; + } + + shape_out.set(i + i_shift, a.tensor_shape()[i]); + } + return shape_out; +} } // namespace shape_calculator } // namespace misc } // namespace arm_compute |