diff options
author | Gunes Bayir <gunes.bayir@arm.com> | 2023-03-17 13:52:21 +0000 |
---|---|---|
committer | Gunes Bayir <gunes.bayir@arm.com> | 2023-03-20 14:49:51 +0000 |
commit | 8918b23073851417e8be6e5e53c6380dbdedf201 (patch) | |
tree | ad0eb38aa7086adb71a444802009a04de3e34929 /src/gpu/cl | |
parent | 14d7b535d48620f009efca576cc70fb6ea9ff20d (diff) | |
download | ComputeLibrary-8918b23073851417e8be6e5e53c6380dbdedf201.tar.gz |
Implement OpenCL MatMul for Lhs T Rhs T/NT FP32/16
- Implement opencl kernel for LHS transposed and RHS non-transposed
- Implement opencl kernel for LHS transposed and RHS transposed
- Add validation tests
Resolves: COMPMID-5953, COMPMID-5955
Change-Id: I55589acbffe86c44e29807574975978a1ec09bad
Signed-off-by: Gunes Bayir <gunes.bayir@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/9345
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'src/gpu/cl')
-rw-r--r-- | src/gpu/cl/ClKernelLibrary.cpp | 6 | ||||
-rw-r--r-- | src/gpu/cl/kernels/ClNativeMatMulKernel.cpp | 59 |
2 files changed, 37 insertions, 28 deletions
diff --git a/src/gpu/cl/ClKernelLibrary.cpp b/src/gpu/cl/ClKernelLibrary.cpp index 8099071fcd..44b086f2fc 100644 --- a/src/gpu/cl/ClKernelLibrary.cpp +++ b/src/gpu/cl/ClKernelLibrary.cpp @@ -319,6 +319,10 @@ const std::map<std::string, std::string> ClKernelLibrary::_kernel_program_map = { "l2_normalize_x", "common/l2_normalize.cl" }, { "l2_normalize_y", "common/l2_normalize.cl" }, { "l2_normalize_z", "common/l2_normalize.cl" }, + { "mat_mul_native_nt_nt", "common/mat_mul.cl" }, + { "mat_mul_native_nt_t", "common/mat_mul.cl" }, + { "mat_mul_native_t_nt", "common/mat_mul.cl" }, + { "mat_mul_native_t_t", "common/mat_mul.cl" }, { "max_unpooling_layer_2", "common/unpooling_layer.cl" }, { "mean_stddev_normalization", "common/mean_stddev_normalization.cl" }, { "memset", "common/memset.cl" }, @@ -359,8 +363,6 @@ const std::map<std::string, std::string> ClKernelLibrary::_kernel_program_map = { "strided_slice", "common/slice_ops.cl" }, { "tile", "common/tile.cl" }, { "transpose", "common/transpose.cl" }, - { "mat_mul_native_nt_nt", "common/mat_mul.cl" }, - { "mat_mul_native_nt_t", "common/mat_mul.cl" }, #ifdef ENABLE_NCHW_KERNELS { "batch_to_space_nchw", "nchw/batch_to_space.cl" }, { "batch_to_space_static_nchw", "nchw/batch_to_space.cl" }, diff --git a/src/gpu/cl/kernels/ClNativeMatMulKernel.cpp b/src/gpu/cl/kernels/ClNativeMatMulKernel.cpp index 6a4db65922..ffbaf49c02 100644 --- a/src/gpu/cl/kernels/ClNativeMatMulKernel.cpp +++ b/src/gpu/cl/kernels/ClNativeMatMulKernel.cpp @@ -50,28 +50,40 @@ Status validate_matmul_kernel_info(const MatMulKernelInfo &matmul_kernel_info) const int k0 = matmul_kernel_info.k0; // Validate M0 - if(!adj_lhs) - { - // We support any positive integer, but will test & benchmark only 1 to 8 because > 8 will not efficient - ARM_COMPUTE_RETURN_ERROR_ON_MSG(m0 < 1, "Only positive integers are supported for M0 for Lhs non-transposed"); - } - else + ARM_COMPUTE_RETURN_ERROR_ON_MSG(m0 < 1, "Only positive integers are supported for M0"); + + if(adj_lhs) { - ARM_COMPUTE_RETURN_ERROR_ON_MSG((m0 & (m0 - 1)) && (m0 != 3) && (m0 > 16), "Only 1,2,3,4,8,16 are supported for N0 for Lhs transposed"); + ARM_COMPUTE_RETURN_ERROR_ON_MSG(((m0 & (m0 - 1)) && (m0 != 3)) || (m0 > 16), "Only 1,2,3,4,8,16 are supported for N0 for Lhs transposed"); } // Validate N0 - ARM_COMPUTE_RETURN_ERROR_ON_MSG((n0 & (n0 - 1)) && (n0 != 3) && (n0 > 16), "Only 1,2,3,4,8,16 are supported for N0"); + ARM_COMPUTE_RETURN_ERROR_ON_MSG(n0 < 1, "Only positive integers are supported for N0"); + ARM_COMPUTE_RETURN_ERROR_ON_MSG(((n0 & (n0 - 1)) && (n0 != 3)) || (n0 > 16), "Only 1,2,3,4,8,16 are supported for N0"); // Validate K0 - if(adj_lhs && !adj_rhs) + ARM_COMPUTE_RETURN_ERROR_ON_MSG(k0 < 1, "Only positive integers are supported for K0"); + if(!adj_lhs || adj_rhs) { - // We support any positive integer, but will test & benchmark only 1 to 8 because > 8 will not efficient - ARM_COMPUTE_RETURN_ERROR_ON_MSG(k0 < 1, "Only positive integers are supported for K0 for Lhs transposed & Rhs non-transposed"); + ARM_COMPUTE_RETURN_ERROR_ON_MSG(((k0 & (k0 - 1)) && (k0 != 3)) || (k0 > 16), "Only 1,2,3,4,8,16 are supported for K0"); } - else + + return Status{}; +} + +Status validate_input_shapes(const TensorShape &lhs_shape, const TensorShape &rhs_shape, const MatMulKernelInfo &matmul_kernel_info) +{ + const size_t lhs_k = matmul_kernel_info.adj_lhs ? lhs_shape.y() : lhs_shape.x(); + const size_t rhs_k = matmul_kernel_info.adj_rhs ? rhs_shape.x() : rhs_shape.y(); + + ARM_COMPUTE_RETURN_ERROR_ON_MSG(lhs_k != rhs_k, "K dimension in Lhs and Rhs matrices must match."); + ARM_COMPUTE_RETURN_ERROR_ON_MSG(lhs_shape.total_size() == 0, "Lhs tensor can't be empty"); + ARM_COMPUTE_RETURN_ERROR_ON_MSG(rhs_shape.total_size() == 0, "Rhs tensor can't be empty"); + + constexpr size_t batch_dim_start = 2; + for(size_t i = batch_dim_start; i < Coordinates::num_max_dimensions; ++i) { - ARM_COMPUTE_RETURN_ERROR_ON_MSG((k0 & (k0 - 1)) && (k0 != 3) && (k0 > 16), "Only 1,2,3,4,8,16 are supported for K0"); + ARM_COMPUTE_RETURN_ERROR_ON_MSG(lhs_shape[i] != rhs_shape[i], "Batch dimension broadcasting is not supported"); } return Status{}; @@ -87,15 +99,14 @@ Status ClNativeMatMulKernel::validate(const ITensorInfo *lhs, const ITensorInfo ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(lhs, 1, DataType::F32, DataType::F16); ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(lhs, rhs); ARM_COMPUTE_RETURN_ON_ERROR(validate_matmul_kernel_info(matmul_kernel_info)); + ARM_COMPUTE_RETURN_ON_ERROR(validate_input_shapes(lhs->tensor_shape(), rhs->tensor_shape(), matmul_kernel_info)); if(output->total_size() != 0) { - const TensorInfo tensor_info_output = output->clone()->set_tensor_shape(misc::shape_calculator::compute_batchmatmul_shape(lhs->tensor_shape(), rhs->tensor_shape(), matmul_kernel_info)); + const TensorInfo tensor_info_output = output->clone()->set_tensor_shape(misc::shape_calculator::compute_matmul_shape(lhs->tensor_shape(), rhs->tensor_shape(), matmul_kernel_info)); ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(output, &tensor_info_output); ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(lhs, output); } - ARM_COMPUTE_RETURN_ERROR_ON_MSG(matmul_kernel_info.adj_lhs && matmul_kernel_info.adj_rhs, "LHS T and RHS T not implemented"); - ARM_COMPUTE_RETURN_ERROR_ON_MSG(matmul_kernel_info.adj_lhs && !matmul_kernel_info.adj_rhs, "LHS T and RHS NT not implemented"); return Status{}; } @@ -105,14 +116,15 @@ void ClNativeMatMulKernel::configure(const ClCompileContext &compile_context, IT ARM_COMPUTE_LOG_PARAMS(lhs, rhs, output, matmul_kernel_info); // output tensor auto initialization if not yet initialized - auto_init_if_empty(*output, lhs->clone()->set_tensor_shape(misc::shape_calculator::compute_batchmatmul_shape(lhs->tensor_shape(), rhs->tensor_shape(), matmul_kernel_info))); + auto_init_if_empty(*output, lhs->clone()->set_tensor_shape(misc::shape_calculator::compute_matmul_shape(lhs->tensor_shape(), rhs->tensor_shape(), matmul_kernel_info))); ARM_COMPUTE_ERROR_THROW_ON(validate(lhs, rhs, output, matmul_kernel_info)); - const int m = output->dimension(1); - const int n = output->dimension(0); - const int k = matmul_kernel_info.adj_lhs ? lhs->tensor_shape().y() : lhs->tensor_shape().x(); + const int m = output->dimension(1); + const int n = output->dimension(0); + const int k = matmul_kernel_info.adj_lhs ? lhs->tensor_shape().y() : lhs->tensor_shape().x(); + const bool adj_lhs = matmul_kernel_info.adj_lhs; - int m0 = std::min(matmul_kernel_info.m0, m); + int m0 = adj_lhs ? adjust_vec_size(matmul_kernel_info.m0, m) : std::min(matmul_kernel_info.m0, m); int n0 = adjust_vec_size(matmul_kernel_info.n0, n); // Configure kernel window @@ -137,11 +149,6 @@ void ClNativeMatMulKernel::configure(const ClCompileContext &compile_context, IT kernel_name += matmul_kernel_info.adj_lhs ? "_t" : "_nt"; kernel_name += matmul_kernel_info.adj_rhs ? "_t" : "_nt"; - if(matmul_kernel_info.adj_lhs) - { - ARM_COMPUTE_ERROR("Only Implemented LHS non-transposed kernels"); - } - // A macro guard to compile ONLY the kernel of interest build_opts.add_option("-D" + upper_string(kernel_name)); |