From 00474e99260da69c5abd14277d0dd0b6de209904 Mon Sep 17 00:00:00 2001 From: Gunes Bayir Date: Mon, 19 Jun 2023 21:33:51 +0100 Subject: Implement FP32/16 MatMul Lhs T Rhs T/NT kernel using MMUL extension Resolves: COMPMID-6196, COMPMID-6197 Change-Id: I22a1c32686eb70e7676c8b4d64a76dbaeb638cb3 Signed-off-by: Gunes Bayir Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/9798 Tested-by: Arm Jenkins Comments-Addressed: Arm Jenkins Reviewed-by: Viet-Hoa Do Benchmark: Arm Jenkins --- src/gpu/cl/kernels/ClMatMulNativeMMULKernel.cpp | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) (limited to 'src/gpu/cl/kernels/ClMatMulNativeMMULKernel.cpp') diff --git a/src/gpu/cl/kernels/ClMatMulNativeMMULKernel.cpp b/src/gpu/cl/kernels/ClMatMulNativeMMULKernel.cpp index 06a0bdee17..4630ec08e9 100644 --- a/src/gpu/cl/kernels/ClMatMulNativeMMULKernel.cpp +++ b/src/gpu/cl/kernels/ClMatMulNativeMMULKernel.cpp @@ -64,14 +64,17 @@ Status validate_matmul_kernel_info(const MatMulKernelInfo &matmul_kernel_info) const int n0 = matmul_kernel_info.n0; const int k0 = matmul_kernel_info.k0; - ARM_COMPUTE_RETURN_ERROR_ON_MSG((adj_lhs), "adj_lhs is not supported yet"); - // Validate M0 ARM_COMPUTE_RETURN_ERROR_ON_MSG(m0 < 1, "Only positive integers are supported for M0"); + if(adj_lhs) + { + ARM_COMPUTE_RETURN_ERROR_ON_MSG((m0 != 1) && (m0 != 2) && (m0 != 3) && (m0 != 4) && (m0 != 8) && (m0 != 16), "Only 1,2,3,4,8,16 are supported for M0 for Lhs transposed"); + } + // Validate N0 ARM_COMPUTE_RETURN_ERROR_ON_MSG(n0 < 1, "Only positive integers are supported for N0"); - ARM_COMPUTE_RETURN_ERROR_ON_MSG(((n0 & (n0 - 1)) && (n0 != 3)) || (n0 > 16), "Only 1,2,3,4,8,16 are supported for N0"); + ARM_COMPUTE_RETURN_ERROR_ON_MSG((n0 != 1) && (n0 != 2) && (n0 != 3) && (n0 != 4) && (n0 != 8) && (n0 != 16), "Only 1,2,3,4,8,16 are supported for N0"); // Validate K0 ARM_COMPUTE_RETURN_ERROR_ON_MSG((k0 != 1), "Only 1 is supported for k0"); @@ -81,8 +84,7 @@ Status validate_matmul_kernel_info(const MatMulKernelInfo &matmul_kernel_info) Status validate_input_shapes(const TensorShape &lhs_shape, const TensorShape &rhs_shape, const MatMulKernelInfo &matmul_kernel_info) { - ARM_COMPUTE_UNUSED(matmul_kernel_info); - const size_t lhs_k = lhs_shape.x(); + const size_t lhs_k = matmul_kernel_info.adj_lhs ? lhs_shape.y() : lhs_shape.x(); const size_t rhs_k = matmul_kernel_info.adj_rhs ? rhs_shape.x() : rhs_shape.y(); ARM_COMPUTE_RETURN_ERROR_ON_MSG(lhs_k != rhs_k, "K dimension in Lhs and Rhs matrices must match."); -- cgit v1.2.1