aboutsummaryrefslogtreecommitdiff
path: root/arm_compute/core/utils
diff options
context:
space:
mode:
Diffstat (limited to 'arm_compute/core/utils')
-rw-r--r--arm_compute/core/utils/misc/ShapeCalculator.h16
1 files changed, 16 insertions, 0 deletions
diff --git a/arm_compute/core/utils/misc/ShapeCalculator.h b/arm_compute/core/utils/misc/ShapeCalculator.h
index 1a86d27727..9c7cfecd4c 100644
--- a/arm_compute/core/utils/misc/ShapeCalculator.h
+++ b/arm_compute/core/utils/misc/ShapeCalculator.h
@@ -57,6 +57,22 @@ inline TensorShape compute_permutation_output_shape(const ITensorInfo &input, co
permute(output_shape, perm);
return output_shape;
}
+inline TensorShape compute_reorg_output_shape(const ITensorInfo &input, int32_t stride)
+{
+ ARM_COMPUTE_ERROR_ON(stride <= 0);
+
+ const DataLayout data_layout = input.data_layout();
+ const unsigned int width_idx = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH);
+ const unsigned int height_idx = get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT);
+ const unsigned int channel_idx = get_data_layout_dimension_index(data_layout, DataLayoutDimension::CHANNEL);
+
+ TensorShape output_shape{ input.tensor_shape() };
+ output_shape.set(width_idx, input.tensor_shape()[width_idx] / stride);
+ output_shape.set(height_idx, input.tensor_shape()[height_idx] / stride);
+ output_shape.set(channel_idx, input.tensor_shape()[channel_idx] * stride * stride);
+
+ return output_shape;
+}
inline TensorShape compute_weights_reshaped_shape(const ITensorInfo &weights, bool has_bias = false, unsigned int num_groups = 1)
{
// Number of groups greater than one are only supported for NCHW data layout, and the number of weights must be a multiple of it.