Diffstat (limited to 'src/gpu/cl/kernels/ClMatMulNativeMMULKernel.cpp')
-rw-r--r--  src/gpu/cl/kernels/ClMatMulNativeMMULKernel.cpp | 90
1 file changed, 14 insertions(+), 76 deletions(-)
diff --git a/src/gpu/cl/kernels/ClMatMulNativeMMULKernel.cpp b/src/gpu/cl/kernels/ClMatMulNativeMMULKernel.cpp
index 44c720a40c..2420ad6a78 100644
--- a/src/gpu/cl/kernels/ClMatMulNativeMMULKernel.cpp
+++ b/src/gpu/cl/kernels/ClMatMulNativeMMULKernel.cpp
@@ -28,13 +28,14 @@
#include "arm_compute/core/ITensorPack.h"
#include "arm_compute/core/KernelDescriptors.h"
#include "arm_compute/core/TensorInfo.h"
+#include "arm_compute/core/utils/StringUtils.h"
#include "arm_compute/core/utils/helpers/AdjustVecSize.h"
#include "arm_compute/core/utils/misc/ShapeCalculator.h"
-#include "arm_compute/core/utils/StringUtils.h"
#include "src/common/utils/Log.h"
#include "src/core/helpers/AutoConfiguration.h"
#include "src/core/helpers/WindowHelpers.h"
+#include "src/gpu/cl/kernels/helpers/MatMulKernelHelpers.h"
#include "support/Cast.h"
#include "support/StringSupport.h"
@@ -52,13 +53,6 @@ constexpr int mmul_m0 = 4;
constexpr int mmul_n0 = 4;
constexpr int mmul_k0 = 4;
-inline std::pair<int, int> adjust_m0_n0(int m0, int n0, int m, int n)
-{
- m0 = std::min(m0, m);
- n0 = adjust_vec_size(n0, n);
- return { m0, n0 };
-}
-
Status validate_matmul_kernel_info(const MatMulKernelInfo &matmul_kernel_info)
{
const bool adj_lhs = matmul_kernel_info.adj_lhs;
@@ -83,68 +77,6 @@ Status validate_matmul_kernel_info(const MatMulKernelInfo &matmul_kernel_info)
return Status{};
}
-
-Status validate_input_shapes(const TensorShape &lhs_shape, const TensorShape &rhs_shape, const MatMulKernelInfo &matmul_kernel_info)
-{
- const size_t lhs_k = matmul_kernel_info.adj_lhs ? lhs_shape.y() : lhs_shape.x();
- const size_t rhs_k = matmul_kernel_info.adj_rhs ? rhs_shape.x() : rhs_shape.y();
-
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(lhs_k != rhs_k, "K dimension in Lhs and Rhs matrices must match.");
- ARM_COMPUTE_RETURN_ERROR_ON_MSG_VAR((lhs_k % mmul_k0) != 0, "K dimension must be a multiple of %d", mmul_k0);
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(lhs_shape.total_size() == 0, "Lhs tensor can't be empty");
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(rhs_shape.total_size() == 0, "Rhs tensor can't be empty");
-
- constexpr size_t batch_dim_start = 2;
- for(size_t i = batch_dim_start; i < Coordinates::num_max_dimensions; ++i)
- {
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(lhs_shape[i] != rhs_shape[i], "Batch dimension broadcasting is not supported");
- }
-
- return Status{};
-}
-
-std::pair<Status, Window> validate_and_configure_window(ITensorInfo *lhs, ITensorInfo *rhs, ITensorInfo *dst, const MatMulKernelInfo &matmul_kernel_info)
-{
- ARM_COMPUTE_UNUSED(lhs, rhs);
-
- const Window win = calculate_max_window(*dst, Steps(1, 1));
-
- // Collapse along the Z direction
- // This collapse needs to be here in order to tune the Z dimension of LWS
- Window collapsed = win.collapse(win, Window::DimZ);
-
- // Reconfigure window size, one arm_matrix_multiply call needs 16 threads to finish.
- Window::Dimension x_dimension = collapsed.x();
- Window::Dimension y_dimension = collapsed.y();
-
- const int m = dst->dimension(1);
- const int n = dst->dimension(0);
-
- int m0{};
- int n0{};
- std::tie(m0, n0) = adjust_m0_n0(matmul_kernel_info.m0, matmul_kernel_info.n0, m, n);
-
- // Make M and N multiple of M0 and N0 respectively
- const unsigned int ceil_to_multiple_n_n0 = ceil_to_multiple(n, n0);
- const unsigned int ceil_to_multiple_m_m0 = ceil_to_multiple(m, m0);
-
- // Divide M and N by M0 and N0 respectively
- const unsigned int n_div_n0 = ceil_to_multiple_n_n0 / n0;
- const unsigned int m_div_m0 = ceil_to_multiple_m_m0 / m0;
-
- // Make n_div_n0 and m_div_m0 multiple of mmul_n0 and mmul_m0 respectively
- const unsigned int ceil_to_multiple_n_div_n0_mmul_n0 = ceil_to_multiple(n_div_n0, mmul_n0);
- const unsigned int ceil_to_multiple_m_div_m0_mmul_m0 = ceil_to_multiple(m_div_m0, mmul_m0);
-
- // Ensure x_dimension is multiple of MMUL block size (mmul_m0 * mmul_n0)
- x_dimension.set_end(ceil_to_multiple_n_div_n0_mmul_n0 * mmul_m0);
- y_dimension.set_end(ceil_to_multiple_m_div_m0_mmul_m0 / mmul_m0);
-
- collapsed.set(Window::DimX, x_dimension);
- collapsed.set(Window::DimY, y_dimension);
-
- return std::make_pair(Status{}, collapsed);
-}
}
ClMatMulNativeMMULKernel::ClMatMulNativeMMULKernel()
{
@@ -158,9 +90,14 @@ Status ClMatMulNativeMMULKernel::validate(const ITensorInfo *lhs, const ITensorI
ARM_COMPUTE_RETURN_ERROR_ON_MSG(!arm_matrix_multiply_supported(CLKernelLibrary::get().get_device()), "The extension cl_arm_matrix_multiply is not supported on the target platform");
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(lhs, rhs);
ARM_COMPUTE_RETURN_ON_ERROR(validate_matmul_kernel_info(matmul_kernel_info));
- ARM_COMPUTE_RETURN_ON_ERROR(validate_input_shapes(lhs->tensor_shape(), rhs->tensor_shape(), matmul_kernel_info));
- const TensorShape expected_output_shape = misc::shape_calculator::compute_matmul_shape(lhs->tensor_shape(), rhs->tensor_shape(), matmul_kernel_info);
+ const TensorShape &lhs_shape = lhs->tensor_shape();
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_matmul_input_shapes(lhs_shape, rhs->tensor_shape(), matmul_kernel_info));
+
+ const size_t lhs_k = matmul_kernel_info.adj_lhs ? lhs_shape.y() : lhs_shape.x();
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG_VAR((lhs_k % mmul_k0) != 0, "K dimension must be a multiple of %d", mmul_k0);
+
+ const TensorShape expected_output_shape = misc::shape_calculator::compute_matmul_shape(lhs_shape, rhs->tensor_shape(), matmul_kernel_info);
if(dst->total_size() != 0)
{
@@ -195,12 +132,13 @@ void ClMatMulNativeMMULKernel::configure(const ClCompileContext &compile_context
_n = n;
_k = k;
- int m0{};
- int n0{};
- std::tie(m0, n0) = adjust_m0_n0(matmul_kernel_info.m0, matmul_kernel_info.n0, m, n);
+ const int m0 = std::min(matmul_kernel_info.m0, m);
+ const int n0 = adjust_vec_size(matmul_kernel_info.n0, n);
// Configure kernel window
- const auto win_config = validate_and_configure_window(lhs, rhs, dst, matmul_kernel_info);
+ const auto win_config = validate_and_configure_window_for_mmul_kernels(lhs, rhs, dst, matmul_kernel_info, mmul_m0,
+ mmul_n0);
+
ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
IClKernel::configure_internal(win_config.second);
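
For reference, below is a minimal standalone sketch of the execution-window arithmetic that this commit moves out of the removed validate_and_configure_window() and into the shared validate_and_configure_window_for_mmul_kernels() helper. The ceil_to_multiple and adjust_vec_size functions here are simplified stand-ins for the arm_compute helpers (the real adjust_vec_size has extra special cases), and the dimensions are illustrative only:

// Sketch only: distilled from the removed validate_and_configure_window,
// not the actual arm_compute implementation.
#include <algorithm>
#include <cstdio>

constexpr int mmul_m0 = 4; // rows per arm_matrix_multiply block
constexpr int mmul_n0 = 4; // columns per arm_matrix_multiply block

// Round 'value' up to the next multiple of 'divisor'.
int ceil_to_multiple(int value, int divisor)
{
    return ((value + divisor - 1) / divisor) * divisor;
}

// Simplified stand-in for arm_compute::adjust_vec_size: halve the
// requested vector size until it does not exceed the dimension.
int adjust_vec_size(int vec_size, int dim)
{
    while(vec_size > dim)
    {
        vec_size /= 2;
    }
    return vec_size;
}

int main()
{
    const int m = 10, n = 10; // output matrix dimensions (illustrative)
    int m0 = 4, n0 = 4;       // requested block sizes

    // Clamp the block sizes to the output dimensions, as configure() does.
    m0 = std::min(m0, m);
    n0 = adjust_vec_size(n0, n);

    // Threads needed per axis before MMUL-block alignment.
    const int n_div_n0 = ceil_to_multiple(n, n0) / n0;
    const int m_div_m0 = ceil_to_multiple(m, m0) / m0;

    // One arm_matrix_multiply call needs mmul_m0 * mmul_n0 = 16 threads,
    // so the X dimension absorbs a factor of mmul_m0 that Y gives up.
    const int x_end = ceil_to_multiple(n_div_n0, mmul_n0) * mmul_m0;
    const int y_end = ceil_to_multiple(m_div_m0, mmul_m0) / mmul_m0;

    std::printf("window: x = [0, %d), y = [0, %d)\n", x_end, y_end);
    return 0;
}

With m = n = 10 and 4x4 blocks this prints a 16x1 window, i.e. exactly the 16 threads one arm_matrix_multiply call needs to finish, matching the comment in the removed code.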