diff options
Diffstat (limited to 'arm_compute/core/utils/misc/ShapeCalculator.h')
-rw-r--r-- | arm_compute/core/utils/misc/ShapeCalculator.h | 10 |
1 files changed, 10 insertions, 0 deletions
diff --git a/arm_compute/core/utils/misc/ShapeCalculator.h b/arm_compute/core/utils/misc/ShapeCalculator.h index cb04182c21..806149f83f 100644 --- a/arm_compute/core/utils/misc/ShapeCalculator.h +++ b/arm_compute/core/utils/misc/ShapeCalculator.h @@ -557,6 +557,16 @@ inline TensorShape compute_split_shape(const ITensorInfo *input, unsigned int ax return out_shape; } +inline TensorShape compute_space_to_batch_shape(const ITensorInfo *input, const int block_x, const int block_y, const Size2D &padding_left, const Size2D &padding_right) +{ + TensorShape output_shape{ input->tensor_shape() }; + output_shape.set(0, input->tensor_shape()[0] * block_x + padding_left.x() + padding_right.x()); + output_shape.set(1, input->tensor_shape()[1] * block_y + padding_left.y() + padding_right.y()); + output_shape.set(3, input->tensor_shape()[3] / (block_x * block_y)); + + return output_shape; +} + template <typename T> inline TensorShape extract_shape(T *data) { |