diff options
Diffstat (limited to 'arm_compute/core/utils/misc/ShapeCalculator.h')
-rw-r--r-- | arm_compute/core/utils/misc/ShapeCalculator.h | 16 |
1 files changed, 16 insertions, 0 deletions
diff --git a/arm_compute/core/utils/misc/ShapeCalculator.h b/arm_compute/core/utils/misc/ShapeCalculator.h index 1a86d27727..9c7cfecd4c 100644 --- a/arm_compute/core/utils/misc/ShapeCalculator.h +++ b/arm_compute/core/utils/misc/ShapeCalculator.h @@ -57,6 +57,22 @@ inline TensorShape compute_permutation_output_shape(const ITensorInfo &input, co permute(output_shape, perm); return output_shape; } +inline TensorShape compute_reorg_output_shape(const ITensorInfo &input, int32_t stride) +{ + ARM_COMPUTE_ERROR_ON(stride <= 0); + + const DataLayout data_layout = input.data_layout(); + const unsigned int width_idx = get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH); + const unsigned int height_idx = get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT); + const unsigned int channel_idx = get_data_layout_dimension_index(data_layout, DataLayoutDimension::CHANNEL); + + TensorShape output_shape{ input.tensor_shape() }; + output_shape.set(width_idx, input.tensor_shape()[width_idx] / stride); + output_shape.set(height_idx, input.tensor_shape()[height_idx] / stride); + output_shape.set(channel_idx, input.tensor_shape()[channel_idx] * stride * stride); + + return output_shape; +} inline TensorShape compute_weights_reshaped_shape(const ITensorInfo &weights, bool has_bias = false, unsigned int num_groups = 1) { // Number of groups greater than one are only supported for NCHW data layout, and the number of weights must be a multiple of it. |