diff options
Diffstat (limited to 'arm_compute')
-rw-r--r-- | arm_compute/core/utils/misc/ShapeCalculator.h | 29 |
1 files changed, 29 insertions, 0 deletions
diff --git a/arm_compute/core/utils/misc/ShapeCalculator.h b/arm_compute/core/utils/misc/ShapeCalculator.h index 84c0ee5034..9e7c981814 100644 --- a/arm_compute/core/utils/misc/ShapeCalculator.h +++ b/arm_compute/core/utils/misc/ShapeCalculator.h @@ -746,6 +746,35 @@ inline TensorShape compute_deep_convolution_shape(const ITensorInfo &input, cons return compute_deep_convolution_shape(input.tensor_shape(), input.data_layout(), weights.tensor_shape(), conv_info); } +/** Calculate the shape of the indirect buffer used by the indirect convolution function + * + * @param[in] input_shape Input tensor shape, forwarded unchanged to compute_deep_convolution_shape() + * @param[in] input_data_layout Input data layout. Must be DataLayout::NHWC (checked with ARM_COMPUTE_ERROR_ON_MSG) + * @param[in] weights_shape Weights tensor shape. Dimensions 1 and 2 are read as kw and kh respectively + * @param[in] conv_info Contains padding and stride information, forwarded unchanged to compute_deep_convolution_shape() + * @param[in] desc Contains the direct/indirect convolution compute arguments, such as the tiling dimensions. Only desc.m0 is read here and it must be in the range [1, 8] (checked with ARM_COMPUTE_ERROR_ON_MSG) + * + * @return the calculated three-dimensional shape (m0 x kw x kh, DIV_CEIL(out[1] x out[2], m0), out[3]), where out is the shape returned by compute_deep_convolution_shape() for the same shapes, layout and conv_info + */ +inline TensorShape compute_indirect_buffer_shape(const TensorShape &input_shape, DataLayout input_data_layout, const TensorShape &weights_shape, const PadStrideInfo &conv_info, + const DirectConvComputeKernelInfo &desc) +{ + ARM_COMPUTE_ERROR_ON_MSG(input_data_layout != DataLayout::NHWC, "The data layout can only be NHWC"); + ARM_COMPUTE_ERROR_ON_MSG(desc.m0 <= 0 || desc.m0 > 8, "M0 can only be greater than 0 and less than or equal to 8"); + + const unsigned int m0 = desc.m0; + const unsigned int kw = weights_shape[1]; + const unsigned int kh = weights_shape[2]; + + TensorShape output_conv2d_shape = compute_deep_convolution_shape(input_shape, input_data_layout, weights_shape, conv_info); + + const unsigned int output_w = m0 * kw * kh; + const unsigned int output_h = DIV_CEIL(output_conv2d_shape[1] * output_conv2d_shape[2], m0); + const unsigned int output_b = output_conv2d_shape[3]; + + return TensorShape(output_w, output_h, output_b); +} + /** Calculate the min/max shape output shape of a tensor * * @param[in] input Input tensor info |