diff options
Diffstat (limited to 'arm_compute/core/utils/misc/ShapeCalculator.h')
-rw-r--r-- | arm_compute/core/utils/misc/ShapeCalculator.h | 40 |
1 files changed, 40 insertions, 0 deletions
diff --git a/arm_compute/core/utils/misc/ShapeCalculator.h b/arm_compute/core/utils/misc/ShapeCalculator.h
index 7eab17ba11..010501454f 100644
--- a/arm_compute/core/utils/misc/ShapeCalculator.h
+++ b/arm_compute/core/utils/misc/ShapeCalculator.h
@@ -26,6 +26,7 @@
 
 #include "arm_compute/core/Helpers.h"
 #include "arm_compute/core/ITensorInfo.h"
+#include "arm_compute/core/KernelDescriptors.h"
 #include "arm_compute/core/Utils.h"
 #include "arm_compute/core/utils/helpers/tensor_transform.h"
 
@@ -851,6 +852,8 @@ inline TensorShape compute_mm_shape(const ITensorInfo &input0, const ITensorInfo
 
 /** Calculate the matrix multiplication output shape of two tensors
  *
+ * @note Deprecated. Remove when GEMMReshapeInfo is not used anymore by any other kernels
+ *
  * @param[in] input0 First input tensor info
  * @param[in] input1 Second input tensor info
  * @param[in] gemm_info GEMM reshape info
@@ -888,6 +891,43 @@ inline TensorShape compute_mm_shape(const ITensorInfo &input0, const ITensorInfo
 
 /** Calculate the matrix multiplication output shape of two tensors
  *
+ * @param[in] input0 First input tensor info
+ * @param[in] input1 Second input tensor info
+ * @param[in] gemm_info GEMM kernel info used to retrieve the original dimensions of the input matrices
+ *
+ * @return the calculated shape
+ */
+inline TensorShape compute_mm_shape(const ITensorInfo &input0, const ITensorInfo &input1, const GEMMKernelInfo &gemm_info)
+{
+    ARM_COMPUTE_ERROR_ON_MSG(input0.num_dimensions() > 4, "The number of dimensions for the matrix A must be <= 4");
+
+    const bool reinterpret_input_as_3d = gemm_info.reinterpret_input_as_3d;
+    const bool reinterpret_output_as_3d = gemm_info.depth_output_gemm3d != 0;
+    const unsigned int depth_output_gemm3d = reinterpret_output_as_3d ? gemm_info.depth_output_gemm3d : 1;
+
+    TensorShape output_shape{ input0.tensor_shape() };
+
+    if(!reinterpret_input_as_3d && !reinterpret_output_as_3d)
+    {
+        output_shape.set(0, gemm_info.n);
+        output_shape.set(1, gemm_info.m);
+    }
+    else
+    {
+        // If the output of GEMM has to be reinterpreted as 3D, the number of input0 rows (M) is obtained collapsing the second and third
+        // dimension of the output tensor
+        const unsigned int batch_size = reinterpret_input_as_3d ? input0.tensor_shape()[3] : input0.tensor_shape()[2];
+        output_shape.set(0, gemm_info.n);
+        output_shape.set(1, gemm_info.m / depth_output_gemm3d);
+        output_shape.set(2, reinterpret_output_as_3d ? depth_output_gemm3d : batch_size);
+        output_shape.set(3, reinterpret_output_as_3d ? batch_size : 1);
+    }
+
+    return output_shape;
+}
+
+/** Calculate the matrix multiplication output shape of two tensors
+ *
  * @param[in] input Input tensor info
  * @param[in] gemm_3d_depth (Optional) GEMM 3d depth
  * @param[in] batch_size_on_z (Optional) True if batch size is on z axis