path: root/src/gpu
author    Gunes Bayir <gunes.bayir@arm.com>    2023-03-17 13:52:21 +0000
committer Gunes Bayir <gunes.bayir@arm.com>    2023-03-20 14:49:51 +0000
commit    8918b23073851417e8be6e5e53c6380dbdedf201 (patch)
tree      ad0eb38aa7086adb71a444802009a04de3e34929 /src/gpu
parent    14d7b535d48620f009efca576cc70fb6ea9ff20d (diff)
download  ComputeLibrary-8918b23073851417e8be6e5e53c6380dbdedf201.tar.gz
Implement OpenCL MatMul for Lhs T Rhs T/NT FP32/16
- Implement opencl kernel for LHS transposed and RHS non-transposed
- Implement opencl kernel for LHS transposed and RHS transposed
- Add validation tests

Resolves: COMPMID-5953, COMPMID-5955

Change-Id: I55589acbffe86c44e29807574975978a1ec09bad
Signed-off-by: Gunes Bayir <gunes.bayir@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/9345
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
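The four kernel variants this patch registers differ only in their transposition suffixes. As a minimal sketch (the helper below is hypothetical; the names and the "_t"/"_nt" suffix scheme are taken from the patch itself), the variant name can be derived from the two adjoint flags:

#include <string>

// Illustrative helper: maps the (adj_lhs, adj_rhs) flags onto the four
// kernel names registered in ClKernelLibrary.cpp by this patch.
std::string matmul_kernel_name(bool adj_lhs, bool adj_rhs)
{
    std::string name = "mat_mul_native";
    name += adj_lhs ? "_t" : "_nt"; // "_t" when LHS is transposed
    name += adj_rhs ? "_t" : "_nt"; // "_t" when RHS is transposed
    return name; // e.g. "mat_mul_native_t_nt"
}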
Diffstat (limited to 'src/gpu')
-rw-r--r--  src/gpu/cl/ClKernelLibrary.cpp              |  6
-rw-r--r--  src/gpu/cl/kernels/ClNativeMatMulKernel.cpp | 59
2 files changed, 37 insertions(+), 28 deletions(-)
diff --git a/src/gpu/cl/ClKernelLibrary.cpp b/src/gpu/cl/ClKernelLibrary.cpp
index 8099071fcd..44b086f2fc 100644
--- a/src/gpu/cl/ClKernelLibrary.cpp
+++ b/src/gpu/cl/ClKernelLibrary.cpp
@@ -319,6 +319,10 @@ const std::map<std::string, std::string> ClKernelLibrary::_kernel_program_map =
{ "l2_normalize_x", "common/l2_normalize.cl" },
{ "l2_normalize_y", "common/l2_normalize.cl" },
{ "l2_normalize_z", "common/l2_normalize.cl" },
+ { "mat_mul_native_nt_nt", "common/mat_mul.cl" },
+ { "mat_mul_native_nt_t", "common/mat_mul.cl" },
+ { "mat_mul_native_t_nt", "common/mat_mul.cl" },
+ { "mat_mul_native_t_t", "common/mat_mul.cl" },
{ "max_unpooling_layer_2", "common/unpooling_layer.cl" },
{ "mean_stddev_normalization", "common/mean_stddev_normalization.cl" },
{ "memset", "common/memset.cl" },
@@ -359,8 +363,6 @@ const std::map<std::string, std::string> ClKernelLibrary::_kernel_program_map =
{ "strided_slice", "common/slice_ops.cl" },
{ "tile", "common/tile.cl" },
{ "transpose", "common/transpose.cl" },
- { "mat_mul_native_nt_nt", "common/mat_mul.cl" },
- { "mat_mul_native_nt_t", "common/mat_mul.cl" },
#ifdef ENABLE_NCHW_KERNELS
{ "batch_to_space_nchw", "nchw/batch_to_space.cl" },
{ "batch_to_space_static_nchw", "nchw/batch_to_space.cl" },
diff --git a/src/gpu/cl/kernels/ClNativeMatMulKernel.cpp b/src/gpu/cl/kernels/ClNativeMatMulKernel.cpp
index 6a4db65922..ffbaf49c02 100644
--- a/src/gpu/cl/kernels/ClNativeMatMulKernel.cpp
+++ b/src/gpu/cl/kernels/ClNativeMatMulKernel.cpp
@@ -50,28 +50,40 @@ Status validate_matmul_kernel_info(const MatMulKernelInfo &matmul_kernel_info)
const int k0 = matmul_kernel_info.k0;
// Validate M0
- if(!adj_lhs)
- {
- // We support any positive integer, but will test & benchmark only 1 to 8 because > 8 will not efficient
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(m0 < 1, "Only positive integers are supported for M0 for Lhs non-transposed");
- }
- else
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(m0 < 1, "Only positive integers are supported for M0");
+
+ if(adj_lhs)
{
- ARM_COMPUTE_RETURN_ERROR_ON_MSG((m0 & (m0 - 1)) && (m0 != 3) && (m0 > 16), "Only 1,2,3,4,8,16 are supported for N0 for Lhs transposed");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(((m0 & (m0 - 1)) && (m0 != 3)) || (m0 > 16), "Only 1,2,3,4,8,16 are supported for M0 for Lhs transposed");
}
// Validate N0
- ARM_COMPUTE_RETURN_ERROR_ON_MSG((n0 & (n0 - 1)) && (n0 != 3) && (n0 > 16), "Only 1,2,3,4,8,16 are supported for N0");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(n0 < 1, "Only positive integers are supported for N0");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(((n0 & (n0 - 1)) && (n0 != 3)) || (n0 > 16), "Only 1,2,3,4,8,16 are supported for N0");
// Validate K0
- if(adj_lhs && !adj_rhs)
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(k0 < 1, "Only positive integers are supported for K0");
+ if(!adj_lhs || adj_rhs)
{
- // We support any positive integer, but will test & benchmark only 1 to 8 because > 8 will not efficient
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(k0 < 1, "Only positive integers are supported for K0 for Lhs transposed & Rhs non-transposed");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(((k0 & (k0 - 1)) && (k0 != 3)) || (k0 > 16), "Only 1,2,3,4,8,16 are supported for K0");
}
- else
+
+ return Status{};
+}
+
+Status validate_input_shapes(const TensorShape &lhs_shape, const TensorShape &rhs_shape, const MatMulKernelInfo &matmul_kernel_info)
+{
+ const size_t lhs_k = matmul_kernel_info.adj_lhs ? lhs_shape.y() : lhs_shape.x();
+ const size_t rhs_k = matmul_kernel_info.adj_rhs ? rhs_shape.x() : rhs_shape.y();
+
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(lhs_k != rhs_k, "K dimension in Lhs and Rhs matrices must match.");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(lhs_shape.total_size() == 0, "Lhs tensor can't be empty");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(rhs_shape.total_size() == 0, "Rhs tensor can't be empty");
+
+ constexpr size_t batch_dim_start = 2;
+ for(size_t i = batch_dim_start; i < Coordinates::num_max_dimensions; ++i)
{
- ARM_COMPUTE_RETURN_ERROR_ON_MSG((k0 & (k0 - 1)) && (k0 != 3) && (k0 > 16), "Only 1,2,3,4,8,16 are supported for K0");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(lhs_shape[i] != rhs_shape[i], "Batch dimension broadcasting is not supported");
}
return Status{};
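The reworked checks above share one rule: a block size (M0, N0 or K0) is valid only if it is a power of two no greater than 16, or exactly 3. A standalone sketch of that predicate (the helper name is hypothetical; (x & (x - 1)) == 0 is the usual power-of-two test):

#include <cassert>

// Accepts exactly 1, 2, 3, 4, 8 and 16, mirroring the validation above.
bool is_supported_block_size(int x)
{
    if(x < 1)
    {
        return false; // only positive integers are allowed
    }
    return ((x & (x - 1)) == 0 || x == 3) && x <= 16;
}

int main()
{
    for(int x : {1, 2, 3, 4, 8, 16})
    {
        assert(is_supported_block_size(x));
    }
    for(int x : {0, 5, 6, 7, 12, 32})
    {
        assert(!is_supported_block_size(x));
    }
    return 0;
}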
@@ -87,15 +99,14 @@ Status ClNativeMatMulKernel::validate(const ITensorInfo *lhs, const ITensorInfo
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(lhs, 1, DataType::F32, DataType::F16);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(lhs, rhs);
ARM_COMPUTE_RETURN_ON_ERROR(validate_matmul_kernel_info(matmul_kernel_info));
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_input_shapes(lhs->tensor_shape(), rhs->tensor_shape(), matmul_kernel_info));
if(output->total_size() != 0)
{
- const TensorInfo tensor_info_output = output->clone()->set_tensor_shape(misc::shape_calculator::compute_batchmatmul_shape(lhs->tensor_shape(), rhs->tensor_shape(), matmul_kernel_info));
+ const TensorInfo tensor_info_output = output->clone()->set_tensor_shape(misc::shape_calculator::compute_matmul_shape(lhs->tensor_shape(), rhs->tensor_shape(), matmul_kernel_info));
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(output, &tensor_info_output);
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(lhs, output);
}
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(matmul_kernel_info.adj_lhs && matmul_kernel_info.adj_rhs, "LHS T and RHS T not implemented");
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(matmul_kernel_info.adj_lhs && !matmul_kernel_info.adj_rhs, "LHS T and RHS NT not implemented");
return Status{};
}
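With the two previously rejected transposition cases now implemented, validate() checks a pre-sized output against compute_matmul_shape. A minimal sketch of the expected shape arithmetic, assuming the library's convention that dimension 0 (x) holds columns and dimension 1 (y) holds rows; Shape2D and the helper are illustrative, not the real shape_calculator code:

#include <cstddef>

struct Shape2D { std::size_t x, y; }; // x = columns, y = rows

// Illustrative only: the destination is M x N, where M and N are read from
// whichever LHS/RHS axis is not the shared K axis (cf. validate_input_shapes).
Shape2D matmul_output_shape(Shape2D lhs, Shape2D rhs, bool adj_lhs, bool adj_rhs)
{
    const std::size_t m = adj_lhs ? lhs.x : lhs.y; // LHS is K x M when transposed
    const std::size_t n = adj_rhs ? rhs.y : rhs.x; // RHS is N x K when transposed
    return Shape2D{n, m}; // matches output->dimension(0) == N, dimension(1) == M
}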
@@ -105,14 +116,15 @@ void ClNativeMatMulKernel::configure(const ClCompileContext &compile_context, IT
ARM_COMPUTE_LOG_PARAMS(lhs, rhs, output, matmul_kernel_info);
// output tensor auto initialization if not yet initialized
- auto_init_if_empty(*output, lhs->clone()->set_tensor_shape(misc::shape_calculator::compute_batchmatmul_shape(lhs->tensor_shape(), rhs->tensor_shape(), matmul_kernel_info)));
+ auto_init_if_empty(*output, lhs->clone()->set_tensor_shape(misc::shape_calculator::compute_matmul_shape(lhs->tensor_shape(), rhs->tensor_shape(), matmul_kernel_info)));
ARM_COMPUTE_ERROR_THROW_ON(validate(lhs, rhs, output, matmul_kernel_info));
- const int m = output->dimension(1);
- const int n = output->dimension(0);
- const int k = matmul_kernel_info.adj_lhs ? lhs->tensor_shape().y() : lhs->tensor_shape().x();
+ const int m = output->dimension(1);
+ const int n = output->dimension(0);
+ const int k = matmul_kernel_info.adj_lhs ? lhs->tensor_shape().y() : lhs->tensor_shape().x();
+ const bool adj_lhs = matmul_kernel_info.adj_lhs;
- int m0 = std::min(matmul_kernel_info.m0, m);
+ int m0 = adj_lhs ? adjust_vec_size(matmul_kernel_info.m0, m) : std::min(matmul_kernel_info.m0, m);
int n0 = adjust_vec_size(matmul_kernel_info.n0, n);
// Configure kernel window
@@ -137,11 +149,6 @@ void ClNativeMatMulKernel::configure(const ClCompileContext &compile_context, IT
kernel_name += matmul_kernel_info.adj_lhs ? "_t" : "_nt";
kernel_name += matmul_kernel_info.adj_rhs ? "_t" : "_nt";
- if(matmul_kernel_info.adj_lhs)
- {
- ARM_COMPUTE_ERROR("Only Implemented LHS non-transposed kernels");
- }
-
// A macro guard to compile ONLY the kernel of interest
build_opts.add_option("-D" + upper_string(kernel_name));
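The macro guard on the last line turns the selected kernel name into a compile-time switch, so mat_mul.cl builds only the variant being configured. An illustrative stand-in for the upper_string() call, assuming it performs a plain ASCII uppercase:

#include <algorithm>
#include <cctype>
#include <string>

// Uppercases the kernel name and prepends "-D", so e.g. "mat_mul_native_t_nt"
// becomes the build option -DMAT_MUL_NATIVE_T_NT.
std::string macro_guard(std::string kernel_name)
{
    std::transform(kernel_name.begin(), kernel_name.end(), kernel_name.begin(),
                   [](unsigned char c) { return static_cast<char>(std::toupper(c)); });
    return "-D" + kernel_name;
}

Configuring with adj_lhs = true and adj_rhs = false, for example, would therefore compile only mat_mul_native_t_nt.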