aboutsummaryrefslogtreecommitdiff
path: root/arm_compute/core/utils/misc/ShapeCalculator.h
diff options
context:
space:
mode:
Diffstat (limited to 'arm_compute/core/utils/misc/ShapeCalculator.h')
-rw-r--r--arm_compute/core/utils/misc/ShapeCalculator.h10
1 files changed, 10 insertions, 0 deletions
diff --git a/arm_compute/core/utils/misc/ShapeCalculator.h b/arm_compute/core/utils/misc/ShapeCalculator.h
index cb04182c21..806149f83f 100644
--- a/arm_compute/core/utils/misc/ShapeCalculator.h
+++ b/arm_compute/core/utils/misc/ShapeCalculator.h
@@ -557,6 +557,16 @@ inline TensorShape compute_split_shape(const ITensorInfo *input, unsigned int ax
return out_shape;
}
+inline TensorShape compute_space_to_batch_shape(const ITensorInfo *input, const int block_x, const int block_y, const Size2D &padding_left, const Size2D &padding_right)
+{
+ TensorShape output_shape{ input->tensor_shape() };
+ output_shape.set(0, input->tensor_shape()[0] * block_x + padding_left.x() + padding_right.x());
+ output_shape.set(1, input->tensor_shape()[1] * block_y + padding_left.y() + padding_right.y());
+ output_shape.set(3, input->tensor_shape()[3] / (block_x * block_y));
+
+ return output_shape;
+}
+
template <typename T>
inline TensorShape extract_shape(T *data)
{