diff options
Diffstat (limited to 'arm_compute/core/utils/misc/ShapeCalculator.h')
-rw-r--r-- | arm_compute/core/utils/misc/ShapeCalculator.h | 34 |
1 files changed, 31 insertions, 3 deletions
diff --git a/arm_compute/core/utils/misc/ShapeCalculator.h b/arm_compute/core/utils/misc/ShapeCalculator.h index 6655cc1439..75a063f75c 100644 --- a/arm_compute/core/utils/misc/ShapeCalculator.h +++ b/arm_compute/core/utils/misc/ShapeCalculator.h @@ -21,8 +21,8 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. */ -#ifndef ARM_COMPUTE_MISC_SHAPE_CALCULATOR_H -#define ARM_COMPUTE_MISC_SHAPE_CALCULATOR_H +#ifndef ACL_ARM_COMPUTE_CORE_UTILS_MISC_SHAPECALCULATOR +#define ACL_ARM_COMPUTE_CORE_UTILS_MISC_SHAPECALCULATOR #include "arm_compute/core/Helpers.h" #include "arm_compute/core/ITensorInfo.h" @@ -1008,6 +1008,34 @@ inline TensorShape compute_mm_shape(const ITensorInfo &input0, const ITensorInfo /** Calculate the matrix multiplication output shape of two tensors * + * @param[in] input0 First input tensor info + * @param[in] input1 Second input tensor info + * @param[in] matmul_info Batch MatMul Kernel info to know which matrix is transposed + * + * @return the calculated shape + */ +inline TensorShape compute_batchmatmul_shape(const TensorShape &input0, const TensorShape &input1, const MatMulKernelInfo &matmul_info) +{ + TensorShape output_shape{ input0 }; + + if(matmul_info.adj_lhs) + { + output_shape.set(1, input0[0]); // The vertical (M) dimension + } + + if(matmul_info.adj_rhs) + { + output_shape.set(0, input1[1]); // The horizontal (N) dimension + } + else + { + output_shape.set(0, input1[0]); // The horizontal (N) dimension + } + + return output_shape; +} +/** Calculate the matrix multiplication output shape of two tensors + * * @param[in] input Input tensor info * @param[in] gemm_3d_depth (Optional) GEMM 3d depth * @param[in] batch_size_on_z (Optional) True if batch size is on z axis @@ -1579,4 +1607,4 @@ inline TensorShape compute_gather_shape(const TensorShape &input_shape, const Te } // namespace shape_calculator } // namespace misc } // namespace arm_compute -#endif /* ARM_COMPUTE_MISC_SHAPE_CALCULATOR_H */ +#endif /* ACL_ARM_COMPUTE_CORE_UTILS_MISC_SHAPECALCULATOR */ |