From ae99b6eac40c7c3cb5eb465f3cbe4b522eff0488 Mon Sep 17 00:00:00 2001 From: Giorgio Arena Date: Thu, 1 Aug 2019 14:22:12 +0100 Subject: COMPMID-1965 Extend CLGEMMMatrixMultiplyReshapedKernel to support transposed LHS (t) and not-transpose RHS Change-Id: I437a00d7213fefd6f4365071b46174d44df8b85c Signed-off-by: Giorgio Arena Reviewed-on: https://review.mlplatform.org/c/1677 Tested-by: Arm Jenkins Reviewed-by: Gian Marco Iodice Comments-Addressed: Arm Jenkins --- src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedKernel.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedKernel.cpp') diff --git a/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedKernel.cpp b/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedKernel.cpp index 63451b49b8..9630caefd8 100644 --- a/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedKernel.cpp +++ b/src/core/CL/kernels/CLGEMMMatrixMultiplyReshapedKernel.cpp @@ -67,9 +67,8 @@ Status validate_arguments(const ITensorInfo *input0, const ITensorInfo *input1, ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input0, input1); ARM_COMPUTE_RETURN_ERROR_ON_MSG(input0->num_dimensions() > 4, "The number of dimensions for the LHS matrix must be <= 4"); ARM_COMPUTE_RETURN_ERROR_ON_MSG(input1->num_dimensions() > 3, "The number of dimensions for the RHS matrix must be <= 3"); - ARM_COMPUTE_RETURN_ERROR_ON(lhs_info.transpose); - ARM_COMPUTE_RETURN_ERROR_ON(!rhs_info.transpose); ARM_COMPUTE_RETURN_ERROR_ON(lhs_info.k0 != rhs_info.k0); + ARM_COMPUTE_RETURN_ERROR_ON(lhs_info.transpose == rhs_info.transpose); ARM_COMPUTE_RETURN_ERROR_ON_MSG(((lhs_info.k0 & (lhs_info.k0 - 1)) && lhs_info.k0 != 3), "Only 2,3,4,8,16 are supported for k0"); ARM_COMPUTE_RETURN_ERROR_ON(lhs_info.k0 > 16); ARM_COMPUTE_RETURN_ERROR_ON(lhs_info.m0 < 2 || lhs_info.m0 > 8); @@ -253,6 +252,7 @@ void CLGEMMMatrixMultiplyReshapedKernel::configure(const ICLTensor *input0, cons build_opts.add_option_if(!_slide_matrix_b, "-DMATRIX_B_DEPTH=" + support::cpp11::to_string(input1->info()->dimension(2))); build_opts.add_option_if(lhs_info.interleave, "-DLHS_INTERLEAVE"); build_opts.add_option_if(rhs_info.interleave, "-DRHS_INTERLEAVE"); + build_opts.add_option_if(lhs_info.transpose, "-DLHS_TRANSPOSE"); build_opts.add_option_if(_use_dummy_work_items, "-DDUMMY_WORK_ITEMS"); build_opts.add_option("-DM=" + support::cpp11::to_string(gemm_info.m)); build_opts.add_option("-DN=" + support::cpp11::to_string(gemm_info.n)); -- cgit v1.2.1