diff options
Diffstat (limited to 'src/gpu/cl/kernels/ClMatMulNativeMMULKernel.cpp')
-rw-r--r-- | src/gpu/cl/kernels/ClMatMulNativeMMULKernel.cpp | 90 |
1 files changed, 14 insertions, 76 deletions
diff --git a/src/gpu/cl/kernels/ClMatMulNativeMMULKernel.cpp b/src/gpu/cl/kernels/ClMatMulNativeMMULKernel.cpp index 44c720a40c..2420ad6a78 100644 --- a/src/gpu/cl/kernels/ClMatMulNativeMMULKernel.cpp +++ b/src/gpu/cl/kernels/ClMatMulNativeMMULKernel.cpp @@ -28,13 +28,14 @@ #include "arm_compute/core/ITensorPack.h" #include "arm_compute/core/KernelDescriptors.h" #include "arm_compute/core/TensorInfo.h" +#include "arm_compute/core/utils/StringUtils.h" #include "arm_compute/core/utils/helpers/AdjustVecSize.h" #include "arm_compute/core/utils/misc/ShapeCalculator.h" -#include "arm_compute/core/utils/StringUtils.h" #include "src/common/utils/Log.h" #include "src/core/helpers/AutoConfiguration.h" #include "src/core/helpers/WindowHelpers.h" +#include "src/gpu/cl/kernels/helpers/MatMulKernelHelpers.h" #include "support/Cast.h" #include "support/StringSupport.h" @@ -52,13 +53,6 @@ constexpr int mmul_m0 = 4; constexpr int mmul_n0 = 4; constexpr int mmul_k0 = 4; -inline std::pair<int, int> adjust_m0_n0(int m0, int n0, int m, int n) -{ - m0 = std::min(m0, m); - n0 = adjust_vec_size(n0, n); - return { m0, n0 }; -} - Status validate_matmul_kernel_info(const MatMulKernelInfo &matmul_kernel_info) { const bool adj_lhs = matmul_kernel_info.adj_lhs; @@ -83,68 +77,6 @@ Status validate_matmul_kernel_info(const MatMulKernelInfo &matmul_kernel_info) return Status{}; } - -Status validate_input_shapes(const TensorShape &lhs_shape, const TensorShape &rhs_shape, const MatMulKernelInfo &matmul_kernel_info) -{ - const size_t lhs_k = matmul_kernel_info.adj_lhs ? lhs_shape.y() : lhs_shape.x(); - const size_t rhs_k = matmul_kernel_info.adj_rhs ? rhs_shape.x() : rhs_shape.y(); - - ARM_COMPUTE_RETURN_ERROR_ON_MSG(lhs_k != rhs_k, "K dimension in Lhs and Rhs matrices must match."); - ARM_COMPUTE_RETURN_ERROR_ON_MSG_VAR((lhs_k % mmul_k0) != 0, "K dimension must be a multiple of %d", mmul_k0); - ARM_COMPUTE_RETURN_ERROR_ON_MSG(lhs_shape.total_size() == 0, "Lhs tensor can't be empty"); - ARM_COMPUTE_RETURN_ERROR_ON_MSG(rhs_shape.total_size() == 0, "Rhs tensor can't be empty"); - - constexpr size_t batch_dim_start = 2; - for(size_t i = batch_dim_start; i < Coordinates::num_max_dimensions; ++i) - { - ARM_COMPUTE_RETURN_ERROR_ON_MSG(lhs_shape[i] != rhs_shape[i], "Batch dimension broadcasting is not supported"); - } - - return Status{}; -} - -std::pair<Status, Window> validate_and_configure_window(ITensorInfo *lhs, ITensorInfo *rhs, ITensorInfo *dst, const MatMulKernelInfo &matmul_kernel_info) -{ - ARM_COMPUTE_UNUSED(lhs, rhs); - - const Window win = calculate_max_window(*dst, Steps(1, 1)); - - // Collapse along the Z direction - // This collapse needs to be here in order to tune the Z dimension of LWS - Window collapsed = win.collapse(win, Window::DimZ); - - // Reconfigure window size, one arm_matrix_multiply call needs 16 threads to finish. - Window::Dimension x_dimension = collapsed.x(); - Window::Dimension y_dimension = collapsed.y(); - - const int m = dst->dimension(1); - const int n = dst->dimension(0); - - int m0{}; - int n0{}; - std::tie(m0, n0) = adjust_m0_n0(matmul_kernel_info.m0, matmul_kernel_info.n0, m, n); - - // Make M and N multiple of M0 and N0 respectively - const unsigned int ceil_to_multiple_n_n0 = ceil_to_multiple(n, n0); - const unsigned int ceil_to_multiple_m_m0 = ceil_to_multiple(m, m0); - - // Divide M and N by M0 and N0 respectively - const unsigned int n_div_n0 = ceil_to_multiple_n_n0 / n0; - const unsigned int m_div_m0 = ceil_to_multiple_m_m0 / m0; - - // Make n_div_n0 and m_div_m0 multiple of mmul_n0 and mmul_m0 respectively - const unsigned int ceil_to_multiple_n_div_n0_mmul_n0 = ceil_to_multiple(n_div_n0, mmul_n0); - const unsigned int ceil_to_multiple_m_div_m0_mmul_m0 = ceil_to_multiple(m_div_m0, mmul_m0); - - // Ensure x_dimension is multiple of MMUL block size (mmul_m0 * mmul_n0) - x_dimension.set_end(ceil_to_multiple_n_div_n0_mmul_n0 * mmul_m0); - y_dimension.set_end(ceil_to_multiple_m_div_m0_mmul_m0 / mmul_m0); - - collapsed.set(Window::DimX, x_dimension); - collapsed.set(Window::DimY, y_dimension); - - return std::make_pair(Status{}, collapsed); -} } ClMatMulNativeMMULKernel::ClMatMulNativeMMULKernel() { @@ -158,9 +90,14 @@ Status ClMatMulNativeMMULKernel::validate(const ITensorInfo *lhs, const ITensorI ARM_COMPUTE_RETURN_ERROR_ON_MSG(!arm_matrix_multiply_supported(CLKernelLibrary::get().get_device()), "The extension cl_arm_matrix_multiply is not supported on the target platform"); ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(lhs, rhs); ARM_COMPUTE_RETURN_ON_ERROR(validate_matmul_kernel_info(matmul_kernel_info)); - ARM_COMPUTE_RETURN_ON_ERROR(validate_input_shapes(lhs->tensor_shape(), rhs->tensor_shape(), matmul_kernel_info)); - const TensorShape expected_output_shape = misc::shape_calculator::compute_matmul_shape(lhs->tensor_shape(), rhs->tensor_shape(), matmul_kernel_info); + const TensorShape &lhs_shape = lhs->tensor_shape(); + ARM_COMPUTE_RETURN_ON_ERROR(validate_matmul_input_shapes(lhs_shape, rhs->tensor_shape(), matmul_kernel_info)); + + const size_t lhs_k = matmul_kernel_info.adj_lhs ? lhs_shape.y() : lhs_shape.x(); + ARM_COMPUTE_RETURN_ERROR_ON_MSG_VAR((lhs_k % mmul_k0) != 0, "K dimension must be a multiple of %d", mmul_k0); + + const TensorShape expected_output_shape = misc::shape_calculator::compute_matmul_shape(lhs_shape, rhs->tensor_shape(), matmul_kernel_info); if(dst->total_size() != 0) { @@ -195,12 +132,13 @@ void ClMatMulNativeMMULKernel::configure(const ClCompileContext &compile_context _n = n; _k = k; - int m0{}; - int n0{}; - std::tie(m0, n0) = adjust_m0_n0(matmul_kernel_info.m0, matmul_kernel_info.n0, m, n); + const int m0 = std::min(matmul_kernel_info.m0, m); + const int n0 = adjust_vec_size(matmul_kernel_info.n0, n); // Configure kernel window - const auto win_config = validate_and_configure_window(lhs, rhs, dst, matmul_kernel_info); + const auto win_config = validate_and_configure_window_for_mmul_kernels(lhs, rhs, dst, matmul_kernel_info, mmul_m0, + mmul_n0); + ARM_COMPUTE_ERROR_THROW_ON(win_config.first); IClKernel::configure_internal(win_config.second); |