about | summary | refs | log | tree | commit | diff
path: root/src/gpu/cl/kernels/ClMatMulNativeMMULKernel.cpp
diff options
context:
space:
mode:
author: Gunes Bayir <gunes.bayir@arm.com> 2023-06-19 21:33:51 +0100
committer: Gunes Bayir <gunes.bayir@arm.com> 2023-06-29 13:23:45 +0000
commit: 00474e99260da69c5abd14277d0dd0b6de209904 (patch)
tree: 28238ebbf4721d7aca6fbf6a23658fbe056da055 /src/gpu/cl/kernels/ClMatMulNativeMMULKernel.cpp
parent: 7a698a38c625047bd558027d4cbc493f063739f5 (diff)
download: ComputeLibrary-00474e99260da69c5abd14277d0dd0b6de209904.tar.gz
Implement FP32/16 MatMul Lhs T Rhs T/NT kernel using MMUL extension
Resolves: COMPMID-6196, COMPMID-6197
Change-Id: I22a1c32686eb70e7676c8b4d64a76dbaeb638cb3
Signed-off-by: Gunes Bayir <gunes.bayir@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/9798
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Viet-Hoa Do <viet-hoa.do@arm.com>
Benchmark: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'src/gpu/cl/kernels/ClMatMulNativeMMULKernel.cpp')
-rw-r--r-- src/gpu/cl/kernels/ClMatMulNativeMMULKernel.cpp | 12
1 file changed, 7 insertions, 5 deletions
diff --git a/src/gpu/cl/kernels/ClMatMulNativeMMULKernel.cpp b/src/gpu/cl/kernels/ClMatMulNativeMMULKernel.cpp
index 06a0bdee17..4630ec08e9 100644
--- a/src/gpu/cl/kernels/ClMatMulNativeMMULKernel.cpp
+++ b/src/gpu/cl/kernels/ClMatMulNativeMMULKernel.cpp
@@ -64,14 +64,17 @@ Status validate_matmul_kernel_info(const MatMulKernelInfo &matmul_kernel_info)
const int n0 = matmul_kernel_info.n0;
const int k0 = matmul_kernel_info.k0;
- ARM_COMPUTE_RETURN_ERROR_ON_MSG((adj_lhs), "adj_lhs is not supported yet");
-
// Validate M0
ARM_COMPUTE_RETURN_ERROR_ON_MSG(m0 < 1, "Only positive integers are supported for M0");
+ if(adj_lhs)
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG((m0 != 1) && (m0 != 2) && (m0 != 3) && (m0 != 4) && (m0 != 8) && (m0 != 16), "Only 1,2,3,4,8,16 are supported for M0 for Lhs transposed");
+ }
+
// Validate N0
ARM_COMPUTE_RETURN_ERROR_ON_MSG(n0 < 1, "Only positive integers are supported for N0");
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(((n0 & (n0 - 1)) && (n0 != 3)) || (n0 > 16), "Only 1,2,3,4,8,16 are supported for N0");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG((n0 != 1) && (n0 != 2) && (n0 != 3) && (n0 != 4) && (n0 != 8) && (n0 != 16), "Only 1,2,3,4,8,16 are supported for N0");
// Validate K0
ARM_COMPUTE_RETURN_ERROR_ON_MSG((k0 != 1), "Only 1 is supported for k0");
@@ -81,8 +84,7 @@ Status validate_matmul_kernel_info(const MatMulKernelInfo &matmul_kernel_info)
Status validate_input_shapes(const TensorShape &lhs_shape, const TensorShape &rhs_shape, const MatMulKernelInfo &matmul_kernel_info)
{
- ARM_COMPUTE_UNUSED(matmul_kernel_info);
- const size_t lhs_k = lhs_shape.x();
+ const size_t lhs_k = matmul_kernel_info.adj_lhs ? lhs_shape.y() : lhs_shape.x();
const size_t rhs_k = matmul_kernel_info.adj_rhs ? rhs_shape.x() : rhs_shape.y();
ARM_COMPUTE_RETURN_ERROR_ON_MSG(lhs_k != rhs_k, "K dimension in Lhs and Rhs matrices must match.");