From e9b3ee2badebf91188c1cd0e59d6aaa30ed60985 Mon Sep 17 00:00:00 2001 From: Jakub Sujak Date: Mon, 17 Apr 2023 12:08:48 +0100 Subject: Connect CLMatMul function to quantized kernels and resolve NE BatchMatMul int_8 failures * Adapt the CLMatMul function and ClMatMul operator to use quantized kernels. * Add function-level tests. Resolves: COMPMID-5929 and COMPMID-5811 Change-Id: I5348cdcf07b8074c138e04dfef0a73399377accd Signed-off-by: Jakub Sujak Signed-off-by: Omar Al Khatib Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/9575 Reviewed-by: Mohmun02 Reviewed-by: SiCong Li Comments-Addressed: Arm Jenkins Tested-by: Arm Jenkins --- src/gpu/cl/kernels/ClMatMulNativeKernel.cpp | 41 +++++++++++++++-------------- 1 file changed, 21 insertions(+), 20 deletions(-) (limited to 'src/gpu/cl/kernels/ClMatMulNativeKernel.cpp') diff --git a/src/gpu/cl/kernels/ClMatMulNativeKernel.cpp b/src/gpu/cl/kernels/ClMatMulNativeKernel.cpp index 47dba22e8f..8f53c1998f 100644 --- a/src/gpu/cl/kernels/ClMatMulNativeKernel.cpp +++ b/src/gpu/cl/kernels/ClMatMulNativeKernel.cpp @@ -119,35 +119,36 @@ ClMatMulNativeKernel::ClMatMulNativeKernel() { _type = CLKernelType::GEMM; } -Status ClMatMulNativeKernel::validate(const ITensorInfo *lhs, const ITensorInfo *rhs, const ITensorInfo *output, const MatMulKernelInfo &matmul_kernel_info) + +Status ClMatMulNativeKernel::validate(const ITensorInfo *lhs, const ITensorInfo *rhs, const ITensorInfo *dst, const MatMulKernelInfo &matmul_kernel_info) { - ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(lhs, rhs, output); + ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(lhs, rhs, dst); ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(lhs, 1, DataType::F32, DataType::F16); ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(lhs, rhs); ARM_COMPUTE_RETURN_ON_ERROR(validate_matmul_kernel_info(matmul_kernel_info)); ARM_COMPUTE_RETURN_ON_ERROR(validate_input_shapes(lhs->tensor_shape(), rhs->tensor_shape(), matmul_kernel_info)); ARM_COMPUTE_RETURN_ON_ERROR(validate_export_to_cl_image(rhs, matmul_kernel_info)); - if(output->total_size() != 0) + if(dst->total_size() != 0) { - const TensorInfo tensor_info_output = output->clone()->set_tensor_shape(misc::shape_calculator::compute_matmul_shape(lhs->tensor_shape(), rhs->tensor_shape(), matmul_kernel_info)); - ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(output, &tensor_info_output); - ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(lhs, output); + const TensorInfo tensor_info_dst = dst->clone()->set_tensor_shape(misc::shape_calculator::compute_matmul_shape(lhs->tensor_shape(), rhs->tensor_shape(), matmul_kernel_info)); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(dst, &tensor_info_dst); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(lhs, dst); } return Status{}; } -void ClMatMulNativeKernel::configure(const ClCompileContext &compile_context, ITensorInfo *lhs, ITensorInfo *rhs, ITensorInfo *output, const MatMulKernelInfo &matmul_kernel_info) +void ClMatMulNativeKernel::configure(const ClCompileContext &compile_context, ITensorInfo *lhs, ITensorInfo *rhs, ITensorInfo *dst, const MatMulKernelInfo &matmul_kernel_info) { - ARM_COMPUTE_ERROR_ON_NULLPTR(lhs, rhs, output, &compile_context, &matmul_kernel_info); - ARM_COMPUTE_LOG_PARAMS(lhs, rhs, output, matmul_kernel_info); - ARM_COMPUTE_ERROR_THROW_ON(validate(lhs, rhs, output, matmul_kernel_info)); + ARM_COMPUTE_ERROR_ON_NULLPTR(lhs, rhs, dst, &compile_context, &matmul_kernel_info); + ARM_COMPUTE_LOG_PARAMS(lhs, rhs, dst, matmul_kernel_info); + ARM_COMPUTE_ERROR_THROW_ON(validate(lhs, rhs, dst, matmul_kernel_info)); - // output tensor auto initialization if not yet initialized - auto_init_if_empty(*output, lhs->clone()->set_tensor_shape(misc::shape_calculator::compute_matmul_shape(lhs->tensor_shape(), rhs->tensor_shape(), matmul_kernel_info))); + // dst tensor auto initialization if not yet initialized + auto_init_if_empty(*dst, lhs->clone()->set_tensor_shape(misc::shape_calculator::compute_matmul_shape(lhs->tensor_shape(), rhs->tensor_shape(), matmul_kernel_info))); - const int m = output->dimension(1); - const int n = output->dimension(0); + const int m = dst->dimension(1); + const int n = dst->dimension(0); const int k = matmul_kernel_info.adj_lhs ? lhs->tensor_shape().y() : lhs->tensor_shape().x(); const bool adj_lhs = matmul_kernel_info.adj_lhs; @@ -157,7 +158,7 @@ void ClMatMulNativeKernel::configure(const ClCompileContext &compile_context, IT _export_rhs_to_cl_image = matmul_kernel_info.export_rhs_to_cl_image && !rhs->lock_paddings(); // Configure kernel window - Window win = calculate_max_window(*output, Steps(n0, m0)); + Window win = calculate_max_window(*dst, Steps(n0, m0)); win = win.collapse(win, Window::DimZ); IClKernel::configure_internal(win); @@ -201,7 +202,7 @@ void ClMatMulNativeKernel::configure(const ClCompileContext &compile_context, IT _config_id += "_"; _config_id += support::cpp11::to_string(k); _config_id += "_"; - _config_id += support::cpp11::to_string(output->dimension(2)); + _config_id += support::cpp11::to_string(dst->dimension(2)); _config_id += "_"; _config_id += support::cpp11::to_string(_export_rhs_to_cl_image); _config_id += "_"; @@ -219,9 +220,9 @@ void ClMatMulNativeKernel::run_op(ITensorPack &tensors, const Window &window, cl const ICLTensor *lhs = utils::cast::polymorphic_downcast(tensors.get_const_tensor(TensorType::ACL_SRC_0)); const ICLTensor *rhs = utils::cast::polymorphic_downcast(tensors.get_const_tensor(TensorType::ACL_SRC_1)); - ICLTensor *output = utils::cast::polymorphic_downcast(tensors.get_tensor(TensorType::ACL_DST)); - ARM_COMPUTE_ERROR_ON_NULLPTR(lhs, rhs, output); - ARM_COMPUTE_LOG_PARAMS(lhs, rhs, output); + ICLTensor *dst = utils::cast::polymorphic_downcast(tensors.get_tensor(TensorType::ACL_DST)); + ARM_COMPUTE_ERROR_ON_NULLPTR(lhs, rhs, dst); + ARM_COMPUTE_LOG_PARAMS(lhs, rhs, dst); unsigned int idx = 0; Window window_collapsed = window.collapse(ICLKernel::window(), Window::DimZ); @@ -242,7 +243,7 @@ void ClMatMulNativeKernel::run_op(ITensorPack &tensors, const Window &window, cl } add_3d_tensor_nhw_argument(idx, rhs); - add_3d_tensor_nhw_argument(idx, output); + add_3d_tensor_nhw_argument(idx, dst); enqueue(queue, *this, window_collapsed, lws_hint()); } -- cgit v1.2.1